Matrix Differentiation
https://blog.csdn.net/qq_39942341/article/details/128739604?spm=1001.2014.3001.5502
(the part on differentials is all you need)
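For reference, the derivations below only need the trace form of the differential, the cyclic property of the trace, and one Hadamard-product identity:

$$\mathrm{d}l = \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{X}}^T\mathrm{d}\mathbf{X}\right),\qquad \operatorname{tr}\left(\mathbf{A}\mathbf{B}\right)=\operatorname{tr}\left(\mathbf{B}\mathbf{A}\right),\qquad \operatorname{tr}\left(\mathbf{A}^T\left(\mathbf{B}\odot\mathbf{C}\right)\right) = \operatorname{tr}\left(\left(\mathbf{A}\odot\mathbf{B}\right)^T\mathbf{C}\right)$$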
Let $\mathbf{X}\in \mathbb{R}^{B\times m}$, $\mathbf{W}_1\in \mathbb{R}^{n\times m}$, $\mathbf{1}\in \mathbb{R}^{B\times1}$ (the all-ones vector), $\mathbf{b}_1\in\mathbb{R}^{1\times n}$, $\mathbf{Y}_1\in\mathbb{R}^{B\times n}$,
$\mathbf{W}_2\in \mathbb{R}^{p\times n}$, $\mathbf{b}_2\in\mathbb{R}^{1\times p}$, $\mathbf{Y}_2\in\mathbb{R}^{B\times p}$.
$\sigma\left(\cdot\right)$ is an elementwise activation function, e.g. sigmoid.
$$\mathbf{Y}_1 = \mathbf{X}\mathbf{W}_1^T + \mathbf{1}\mathbf{b}_1\\ \mathbf{A}_1 = \sigma\left(\mathbf{Y}_1\right)\\ \mathbf{Y}_2 = \mathbf{A}_1\mathbf{W}_2^T +\mathbf{1}\mathbf{b}_2\\ \mathbf{A}_2 = \sigma\left(\mathbf{Y}_2\right)\\ l = \frac{1}{2}\|\mathbf{A}-\mathbf{A}_2\|_F^2$$

where $\mathbf{A}\in\mathbb{R}^{B\times p}$ is the target (the loss is half the squared error, summed over all entries).
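This layout matches PyTorch's `nn.Linear`, whose weight has shape `(out_features, in_features)` and which computes $\mathbf{X}\mathbf{W}^T + \mathbf{b}$, so $\mathbf{W}_1$ and $\mathbf{W}_2$ correspond directly to `linear1.weight` and `linear2.weight` in the verification code further down. A minimal forward-pass sketch (shapes only; values are random):

```python
import torch
from torch import nn

B, m, n, p = 3, 5, 4, 6
X = torch.randn(B, m)           # input, shape (B, m)
linear1 = nn.Linear(m, n)       # weight W1: (n, m), bias b1: (n,)
linear2 = nn.Linear(n, p)       # weight W2: (p, n), bias b2: (p,)

A1 = torch.sigmoid(linear1(X))  # A1 = sigma(X W1^T + 1 b1), shape (B, n)
A2 = torch.sigmoid(linear2(A1)) # A2 = sigma(A1 W2^T + 1 b2), shape (B, p)
print(A1.shape, A2.shape)       # torch.Size([3, 4]) torch.Size([3, 6])
```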
Starting from the loss,

$$\frac{\partial l}{\partial \mathbf{A}_2} = \mathbf{A}_2 - \mathbf{A}$$
$$\begin{aligned} \mathrm{d}l &= \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T \mathrm{d}\mathbf{A}_2\right)\\ &=\operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T \mathrm{d}\sigma\left(\mathbf{Y}_2\right)\right)\\ &=\operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{A}_2}^T\left(\sigma^\prime\left(\mathbf{Y}_2\right) \odot\mathrm{d}\mathbf{Y}_2\right)\right) \\ &= \operatorname{tr}\left(\left(\frac{\partial l}{\partial \mathbf{A}_2}\odot\sigma^\prime\left(\mathbf{Y}_2\right) \right)^T\mathrm{d}\mathbf{Y}_2\right) \\ &= \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\mathbf{Y}_2\right) \end{aligned}$$
Therefore
$$\frac{\partial l}{\partial \mathbf{Y}_2} = \frac{\partial l}{\partial \mathbf{A}_2}\odot\sigma^\prime\left(\mathbf{Y}_2\right)$$
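A quick sanity check of this elementwise chain rule (a sketch; `Y2` and the target used here are arbitrary, not tied to the network above):

```python
import torch

Y2 = torch.randn(3, 6, requires_grad=True)
A2 = torch.sigmoid(Y2)
A2.retain_grad()                              # keep dl/dA2 for comparison
l = 0.5 * torch.sum((A2 - torch.ones_like(A2)) ** 2)
l.backward()

# dl/dY2 = dl/dA2 * sigma'(Y2), with sigma'(Y2) = A2 * (1 - A2) for sigmoid
print(torch.allclose(A2.grad * A2 * (1 - A2), Y2.grad))   # True
```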
Expanding $\mathrm{d}\mathbf{Y}_2$ and using the cyclic property of the trace:

$$\begin{aligned} \mathrm{d}l &= \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\mathbf{Y}_2\right)\\ &= \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathrm{d}\left(\mathbf{A}_1\mathbf{W}_2^T +\mathbf{1}\mathbf{b}_2\right)\right)\\ &= \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\left(\mathrm{d}\mathbf{A}_1\right)\mathbf{W}_2^T\right) + \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1\left(\mathrm{d}\mathbf{W}_2^T\right)\right) + \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{1}\left(\mathrm{d}\mathbf{b}_2\right)\right)\\ &= \operatorname{tr}\left(\mathbf{W}_2^T\frac{\partial l}{\partial \mathbf{Y}_2}^T\left(\mathrm{d}\mathbf{A}_1\right)\right) + \operatorname{tr}\left(\left(\mathrm{d}\mathbf{W}_2^T\right)\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1\right) + \operatorname{tr}\left(\frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{1}\left(\mathrm{d}\mathbf{b}_2\right)\right) \end{aligned}$$
Therefore
$$\frac{\partial l}{\partial \mathbf{A}_1} = \frac{\partial l}{\partial \mathbf{Y}_2}\mathbf{W}_2\\ \frac{\partial l}{\partial \mathbf{W}_2} = \frac{\partial l}{\partial \mathbf{Y}_2}^T\mathbf{A}_1\\ \frac{\partial l}{\partial \mathbf{b}_2} =\mathbf{1}^T\frac{\partial l}{\partial \mathbf{Y}_2}$$
Similarly, for the first layer:
$$\frac{\partial l}{\partial \mathbf{Y}_1} = \frac{\partial l}{\partial \mathbf{A}_1}\odot\sigma^\prime\left(\mathbf{Y}_1\right)\\ \frac{\partial l}{\partial \mathbf{W}_1} = \frac{\partial l}{\partial \mathbf{Y}_1}^T\mathbf{X}\\ \frac{\partial l}{\partial \mathbf{b}_1} =\mathbf{1}^T\frac{\partial l}{\partial \mathbf{Y}_1}$$
With sigmoid, $\sigma^{\prime}\left(\mathbf{X}\right) =\sigma\left(\mathbf{X}\right)\odot\left(1-\sigma\left(\mathbf{X}\right)\right)$ (elementwise).
With ReLU, $\left[\sigma^{\prime}\left(\mathbf{X}\right)\right]_{ij} =\begin{cases} 1, & X_{ij}>0\\ 0, & \text{otherwise} \end{cases}$
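Both derivative formulas can be checked numerically against autograd (a sketch; the inputs are arbitrary random tensors):

```python
import torch

x = torch.randn(4, 5, requires_grad=True)
s = torch.sigmoid(x)
(g_sig,) = torch.autograd.grad(s.sum(), x)
print(torch.allclose(g_sig, s * (1 - s)))     # sigmoid: sigma' = sigma * (1 - sigma)

x2 = torch.randn(4, 5, requires_grad=True)
r = torch.relu(x2)
(g_relu,) = torch.autograd.grad(r.sum(), x2)
print(torch.allclose(g_relu, torch.where(x2 > 0, torch.ones_like(x2), torch.zeros_like(x2))))
```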
Code verification
```python
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn


def sigmoid_derivative(Y):
    # sigma'(Y) expressed through the activation output: sigma' = A * (1 - A)
    return Y * (1 - Y)


def relu_derivative(Y):
    # ReLU' is 1 where the activation is positive, 0 elsewhere
    return torch.where(Y > 0, 1, 0)


if __name__ == '__main__':
    B, m, n, p = 3, 5, 4, 6
    linear1 = nn.Linear(m, n)
    active1 = nn.Sigmoid()
    derivative_1 = sigmoid_derivative
    linear2 = nn.Linear(n, p)
    active2 = nn.ReLU()
    derivative_2 = relu_derivative
    A = torch.randn(B, p)
    X = torch.randn(B, m, requires_grad=True)
    Y1 = linear1(X)
    A1 = active1(Y1)
    Y2 = linear2(A1)
    A2 = active2(Y2)
    # l = 1/2 * ||A2 - A||_F^2
    l = torch.sum((A2 - A) ** 2) * 0.5
    l.backward()

    # manual backward pass; note both sigma' terms are evaluated from the
    # activation outputs (A2, A1), which is equivalent for ReLU and sigmoid
    grad_A2 = A2 - A
    grad_Y2 = grad_A2 * derivative_2(A2)
    grad_W2 = torch.mm(grad_Y2.T, A1)
    grad_b2 = torch.mm(torch.ones(B, 1).T, grad_Y2)
    print(torch.allclose(grad_W2, linear2.weight.grad))
    print(torch.allclose(grad_b2, linear2.bias.grad))

    grad_A1 = torch.mm(grad_Y2, linear2.weight)
    grad_Y1 = grad_A1 * derivative_1(A1)
    grad_W1 = torch.mm(grad_Y1.T, X)
    grad_b1 = torch.mm(torch.ones(B, 1).T, grad_Y1)
    print(torch.allclose(grad_W1, linear1.weight.grad))
    print(torch.allclose(grad_b1, linear1.bias.grad))
```
For a row vector $\mathbf{a}\in\mathbb{R}^{1\times n}$, define
$$\operatorname{softmax}\left(\mathbf{a}\right) = \frac{e^{\mathbf{a}}}{e^{\mathbf{a}}\mathbf{1}_n}$$
where $\mathbf{1}_n\in\mathbb{R}^{n\times1}$ is the all-ones column vector and $e^{\mathbf{a}}$ is taken elementwise.
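This is exactly what `F.softmax(..., dim=1)` computes row-wise; a one-line check (a sketch with arbitrary logits):

```python
import torch
import torch.nn.functional as F

a = torch.randn(1, 5)
manual = torch.exp(a) / torch.exp(a).sum(dim=1, keepdim=True)   # e^a / (e^a 1_n)
print(torch.allclose(manual, F.softmax(a, dim=1)))              # True
```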
Let $\mathbf{y}\in\mathbb{R}^{1\times n}$ be one-hot: exactly one entry is 1 and the rest are 0.
The cross entropy is
$$\begin{aligned} \operatorname{ce}\left(\mathbf{a},\mathbf{y}\right) &= -\log\left(\operatorname{softmax}\left(\mathbf{a}\right)\right)\mathbf{y}^T\\ &= -\left(\mathbf{a}-\log \left(e^{\mathbf{a}}\mathbf{1}_n\right)\mathbf{1}_n^T\right)\mathbf{y}^T\\ &= -\mathbf{a}\mathbf{y}^T+\log\left(e^{\mathbf{a}}\mathbf{1}_n\right) \end{aligned}$$

where the last step uses $\mathbf{1}_n^T\mathbf{y}^T=1$ because $\mathbf{y}$ is one-hot.
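This closed form can be checked against `F.cross_entropy` for a single row (a sketch; `n` and the class index are arbitrary):

```python
import torch
import torch.nn.functional as F

n = 4
a = torch.randn(1, n)                        # one row of logits
target = torch.tensor([2])                   # class index, i.e. y is one-hot at position 2
y = F.one_hot(target, num_classes=n).float()

# ce(a, y) = -a y^T + log(e^a 1_n)
ce_manual = -torch.mm(a, y.T) + torch.log(torch.exp(a).sum(dim=1, keepdim=True))
print(torch.allclose(ce_manual.squeeze(), F.cross_entropy(a, target)))   # True
```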
Taking the differential:
$$\begin{aligned} \mathrm{d}l &= \operatorname{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(e^{\mathbf{a}}\odot\mathrm{d} \mathbf{a}\right)\mathbf{1}_n\right)\\ &= \operatorname{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(\mathbf{1}_n^T\right)^T\left(e^{\mathbf{a}}\odot\mathrm{d} \mathbf{a}\right)\right)\\ &= \operatorname{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(\mathbf{1}_n^T\odot e^{\mathbf{a}}\right)^T\left(\mathrm{d} \mathbf{a}\right)\right)\\ &= \operatorname{tr}\left(-\left(\mathrm{d}\mathbf{a}\right) \mathbf{y}^T + \frac{1}{e^{\mathbf{a}}\mathbf{1}_n}\left(e^{\mathbf{a}} \right)^T\left(\mathrm{d} \mathbf{a}\right)\right)\\ &= \operatorname{tr}\left(-\mathbf{y}^T\left(\mathrm{d}\mathbf{a}\right) + \left(\operatorname{softmax}\left(\mathbf{a}\right) \right)^T\left(\mathrm{d} \mathbf{a}\right)\right) \end{aligned}$$
Hence
$$\frac{\partial l}{\partial \mathbf{a}} = \operatorname{softmax}\left(\mathbf{a}\right)-\mathbf{y}$$
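The gradient itself can be checked the same way via autograd (again a single-row sketch):

```python
import torch
import torch.nn.functional as F

n = 4
a = torch.randn(1, n, requires_grad=True)
target = torch.tensor([1])
y = F.one_hot(target, num_classes=n).float()

F.cross_entropy(a, target).backward()        # single sample, so the reduction does not matter
print(torch.allclose(a.grad, F.softmax(a, dim=1) - y))   # True
```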
Now extend to a batch. Let $\mathbf{A}\in\mathbb{R}^{B\times n}$ and $\mathbf{Y} \in\mathbb{R}^{B\times n}$, where every row of $\mathbf{Y}$ is one-hot (one entry equal to 1, the rest 0). Write $\mathbf{a}_i$ for the $i$-th row of $\mathbf{A}$ and $\mathbf{y}_i$ for the $i$-th row of $\mathbf{Y}$. Then
$$\operatorname{softmax}\left(\mathbf{A}\right) = \begin{pmatrix} \operatorname{softmax}\left(\mathbf{a}_1\right)\\ \operatorname{softmax}\left(\mathbf{a}_2\right)\\ \vdots\\ \operatorname{softmax}\left(\mathbf{a}_B\right) \end{pmatrix}$$
With $\mathbf{1}_{B}\in\mathbb{R}^{B\times1}$ the all-ones column vector, the batched loss is
$$\operatorname{ce}\left(\mathbf{A},\mathbf{Y}\right) = \sum_{i=1}^{B}\operatorname{ce}\left(\mathbf{a}_i,\mathbf{y}_i\right) = \mathbf{1}_B^T\log\left(e^{\mathbf{A}}\mathbf{1}_n\right)-\operatorname{tr}\left(\mathbf{A}\mathbf{Y}^T\right)$$
Differentiating row by row gives
$$\frac{\partial l}{\partial \mathbf{A}} = \begin{pmatrix} \frac{\partial l}{\partial \mathbf{a}_1}\\ \frac{\partial l}{\partial \mathbf{a}_2}\\ \vdots\\ \frac{\partial l}{\partial \mathbf{a}_B} \end{pmatrix} = \begin{pmatrix} \operatorname{softmax}\left(\mathbf{a}_1\right) - \mathbf{y}_1\\ \operatorname{softmax}\left(\mathbf{a}_2\right)- \mathbf{y}_2\\ \vdots\\ \operatorname{softmax}\left(\mathbf{a}_B\right)- \mathbf{y}_B \end{pmatrix}=\operatorname{softmax}\left(\mathbf{A}\right)-\mathbf{Y}$$
Verification:
```python
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torch import nn
import torch.nn.functional as F


if __name__ == '__main__':
    B, n = 3, 4
    ce = nn.CrossEntropyLoss(reduction='sum')
    target = torch.empty(B, dtype=torch.long).random_(n)
    target_one_hot = F.one_hot(target, num_classes=n)
    A = torch.randn(B, n, requires_grad=True)
    l = ce(A, target)
    l.backward()

    ones_B = torch.ones(B, 1)
    ones_n = torch.ones(n, 1)
    # ce(A, Y) = 1_B^T log(e^A 1_n) - tr(A Y^T)
    output = (torch.mm(ones_B.T, torch.log(torch.mm(torch.exp(A), ones_n)))
              - torch.mm(A, target_one_hot.T.float()).trace())
    print(torch.allclose(output, l))

    # dl/dA = softmax(A) - Y
    grad_A = F.softmax(A, dim=1) - target_one_hot
    print(torch.allclose(grad_A, A.grad))
```