目录
- 1. 为什么要引入Softmax?
- 2. Softmax的导数计算
- 3. Softmax及其导数的一些性质
- 4. 交叉熵损失的梯度计算
- 5. Softmax的各种变体
- 5.1 Naive Softmax
- 5.2 Safe Softmax
- 5.3 Online Softmax
- Ref
1. 为什么要引入Softmax?
在进行 n n n 分类任务时,神经网络的最后一层输出(logits)是一个 n n n 维向量 z ∈ R n \mathbf{z}\in \mathbb{R}^n z∈Rn。我们希望将其转化为概率分布,以便选取概率最大的类别作为神经网络的最终分类结果。
直观上来讲, z \mathbf{z} z 中哪个分量越大,相应的概率就应该越高,我们可以很容易想到使用下面的公式进行概率建模:
P ( z i ) = z i z 1 + z 2 + ⋯ + z n (1) P(z_i)=\frac{z_i}{z_1+z_2+\cdots+z_n}\tag{1} P(zi)=z1+z2+⋯+znzi(1)
显然有 ∑ i P ( z i ) = 1 \sum_i P(z_i)=1 ∑iP(zi)=1,但由于 z i ∈ R z_i\in \mathbb{R} zi∈R, P ( z i ) P(z_i) P(zi) 可能为负数,这不符合概率的性质。
注意到指数函数 e x e^x ex 满足单调性,非负性,导数不变性,并且可以放大输入值之间的差异,因此我们可以先让 z \mathbf{z} z 过一遍指数函数,然后再使用公式 ( 1 ) (1) (1) 进行概率建模,此时便得到了Softmax公式:
σ i = Softmax ( z i ) = e z i ∑ j = 1 n e z j (2) \sigma_i=\text{Softmax}(z_i)=\frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \tag{2} σi=Softmax(zi)=∑j=1nezjezi(2)
可以看出, Softmax ( ⋅ ) \text{Softmax}(\cdot) Softmax(⋅) 函数是一个 R n → R n \mathbb{R}^n\to\mathbb{R}^n Rn→Rn 的映射。
📝 为什么Softmax函数中要选取 e x e^x ex,而不是 2 x , 3 x 2^x,3^x 2x,3x 等函数?注意到 ( a x ) ′ = a x ln a (a^x)'=a^x\ln a (ax)′=axlna,如果不选取 e x e^x ex,那么神经网络在反向传播计算梯度时会产生大量的 ln a \ln a lna,这对于计算并不友好,而选择 e x e^x ex 就可以简化掉这些繁琐的对数项。
2. Softmax的导数计算
先前我们提到过,Softmax函数接受 n n n 个输入,并产生相应的 n n n 个输出。对于每一个输出,我们都可以对任意一个输入进行求导,所以Softmax函数的导数实际上是一个 n × n n\times n n×n 的雅克比矩阵:
J = [ ∂ σ 1 ∂ z 1 ⋯ ∂ σ 1 ∂ z n ⋮ ⋱ ⋮ ∂ σ n ∂ z 1 ⋯ ∂ σ n ∂ z n ] n × n \mathbf{J}= \begin{bmatrix} \displaystyle \frac{\partial \sigma_1}{\partial z_1} & \cdots & \displaystyle \frac{\partial \sigma_1}{\partial z_n} \\ \vdots & \ddots & \vdots \\ \displaystyle \frac{\partial \sigma_n}{\partial z_1} & \cdots & \displaystyle \frac{\partial \sigma_n}{\partial z_n} \end{bmatrix}_{n\times n} J= ∂z1∂σ1⋮∂z1∂σn⋯⋱⋯∂zn∂σ1⋮∂zn∂σn n×n
接下来我们关注如何计算 ∂ σ i ∂ z j \displaystyle\frac{\partial \sigma_i}{\partial z_j} ∂zj∂σi,此时要分两种情况讨论。
当 i = j i=j i=j 时,此时计算的是对角线上的元素:
∂ σ i ∂ z i = ∂ ∂ z i ( e z i ∑ j = 1 n e z j ) = e z i ⋅ ∑ j = 1 n e z j − e z i ⋅ e z i ( ∑ j = 1 n e z j ) 2 = e z i ⋅ ( ∑ j = 1 n e z j − e z i ) ( ∑ j = 1 n e z j ) ⋅ ( ∑ j = 1 n e z j ) = e z i ∑ j = 1 n e z j ⋅ ( 1 − e z i ∑ j = 1 n e z j ) = σ i ( 1 − σ i ) \begin{aligned} \frac{\partial \sigma_i}{\partial z_i}&=\frac{\partial}{\partial z_i}\left( \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \right) = \frac{e^{z_i}\cdot \sum_{j=1}^n e^{z_j}-e^{z_i}\cdot e^{z_i}}{(\sum_{j=1}^n e^{z_j})^2} \\ &= \frac{e^{z_i}\cdot( \sum_{j=1}^n e^{z_j}-e^{z_i})}{(\sum_{j=1}^n e^{z_j})\cdot (\sum_{j=1}^n e^{z_j})}=\frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}}\cdot \left(1-\frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \right) \\ &= \sigma_i(1-\sigma_i) \end{aligned} ∂zi∂σi=∂zi∂(∑j=1nezjezi)=(∑j=1nezj)2ezi⋅∑j=1nezj−ezi⋅ezi=(∑j=1nezj)⋅(∑j=1nezj)ezi⋅(∑j=1nezj−ezi)=∑j=1nezjezi⋅(1−∑j=1nezjezi)=σi(1−σi)
当 i ≠ j i\neq j i=j 时,此时计算的是非对角线上的元素:
∂ σ i ∂ z j = ∂ ∂ z j ( e z i ∑ j = 1 n e z j ) = 0 − e z i ⋅ e z j ( ∑ j = 1 n e z j ) 2 = − e z i ∑ j = 1 n e z j ⋅ e z j ∑ j = 1 n e z j = − σ i σ j \begin{aligned} \frac{\partial \sigma_i}{\partial z_j}&=\frac{\partial}{\partial z_j}\left( \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \right)=\frac{0-e^{z_i}\cdot e^{z_j}}{(\sum_{j=1}^n e^{z_j})^2} \\ &=-\frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}}\cdot \frac{e^{z_j}}{\sum_{j=1}^n e^{z_j}} \\ &=-\sigma_i\sigma_j \end{aligned} ∂zj∂σi=∂zj∂(∑j=1nezjezi)=(∑j=1nezj)20−ezi⋅ezj=−∑j=1nezjezi⋅∑j=1nezjezj=−σiσj
有了这些结果,我们可以尝试对雅可比矩阵进行一些变换:
J = [ σ 1 ( 1 − σ 1 ) − σ 1 σ 2 ⋯ − σ 1 σ 2 − σ 2 σ 1 σ 2 ( 1 − σ 2 ) ⋯ − σ 2 σ n ⋮ ⋮ ⋱ ⋮ − σ n σ 1 − σ n σ 2 ⋯ σ n ( 1 − σ n ) ] = [ σ 1 0 ⋯ 0 0 σ 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ σ n ] − [ σ 1 σ 1 σ 1 σ 2 ⋯ σ 1 σ n σ 2 σ 1 σ 2 σ 2 ⋯ σ 2 σ n ⋮ ⋮ ⋱ ⋮ σ n σ 1 σ n σ 2 ⋯ σ n σ n ] = diag ( σ ) − σ σ T \begin{aligned} \mathbf{J}&= \begin{bmatrix} \sigma_1(1-\sigma_1) & -\sigma_1\sigma_2 & \cdots & -\sigma_1\sigma_2 \\ -\sigma_2\sigma_1 & \sigma_2(1-\sigma_2) & \cdots &-\sigma_2\sigma_n \\ \vdots & \vdots & \ddots & \vdots \\ -\sigma_n\sigma_1 & -\sigma_n\sigma_2 & \cdots & \sigma_n(1-\sigma_n) \end{bmatrix} \\ &= \begin{bmatrix} \sigma_1 & 0 & \cdots & 0 \\ 0 & \sigma_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \sigma_n \end{bmatrix}- \begin{bmatrix} \sigma_1\sigma_1 & \sigma_1 \sigma_2 & \cdots & \sigma_1 \sigma_n \\ \sigma_2 \sigma_1 & \sigma_2\sigma_2 & \cdots & \sigma_2 \sigma_n \\ \vdots & \vdots & \ddots & \vdots \\ \sigma_n \sigma_1 & \sigma_n \sigma_2 & \cdots & \sigma_n\sigma_n \end{bmatrix} \\ &=\boxed{\text{diag}(\sigma)-\sigma\sigma^{\text{T}}} \end{aligned} J= σ1(1−σ1)−σ2σ1⋮−σnσ1−σ1σ2σ2(1−σ2)⋮−σnσ2⋯⋯⋱⋯−σ1σ2−σ2σn⋮σn(1−σn) = σ10⋮00σ2⋮0⋯⋯⋱⋯00⋮σn − σ1σ1σ2σ1⋮σnσ1σ1σ2σ2σ2⋮σnσ2⋯⋯⋱⋯σ1σnσ2σn⋮σnσn =diag(σ)−σσT
其中 σ = ( σ 1 , σ 2 , ⋯ , σ n ) \sigma=(\sigma_1,\sigma_2,\cdots,\sigma_n) σ=(σ1,σ2,⋯,σn) 是列向量, diag ( σ ) \text{diag}(\sigma) diag(σ) 表示使用向量 σ \sigma σ 构建对角矩阵。
我们可以利用这一结果,编写相应的代码实现Softmax及其导数计算:
import numpy as npdef softmax(z):e_z = np.exp(z)return e_z / e_z.sum()def softmax_jacobian(z):s = softmax(z).reshape(-1, 1)diag_s = np.diagflat(s)outer_s = np.dot(s, s.T)return diag_s - outer_sif __name__ == "__main__":z = np.array([1.0, 2.0, 3.0])print("Input vector z:")print(z)print("\nSoftmax of z:")s = softmax(z)print(s)print("\nJacobian matrix of the softmax function at z:")J = softmax_jacobian(z)print(J)
输出:
Input vector z:
[1. 2. 3.]Softmax of z:
[0.09003057 0.24472847 0.66524096]Jacobian matrix of the softmax function at z:
[[ 0.08192507 -0.02203304 -0.05989202][-0.02203304 0.18483645 -0.1628034 ][-0.05989202 -0.1628034 0.22269543]]
3. Softmax及其导数的一些性质
性质1:Softmax具有平移不变性,即
Softmax ( [ z 1 + a , z 2 + a , ⋯ , z n + a ] ) = Softmax ( [ z 1 , z 2 , ⋯ , z n ] ) \text{Softmax}([z_1+a,z_2+a,\cdots,z_n+a])= \text{Softmax}([z_1,z_2,\cdots,z_n]) Softmax([z1+a,z2+a,⋯,zn+a])=Softmax([z1,z2,⋯,zn])
证明:这是一个十分显然的事情
LHS = e z i + a ∑ j = 1 n e z j + a = e z i ∑ j = 1 n e z j = RHS \text{LHS}=\frac{e^{z_i+a}}{\sum_{j=1}^ne^{z_j+a}}=\frac{e^{z_i}}{\sum_{j=1}^ne^{z_j}}=\text{RHS} LHS=∑j=1nezj+aezi+a=∑j=1nezjezi=RHS
性质2:Softmax的导数 J \mathbf{J} J 中,任意一行或任意一列的和均为 0 0 0.
证明:因为 J \mathbf{J} J 是一个对称矩阵,即 J = J T \mathbf{J}=\mathbf{J}^{\text{T}} J=JT,我们只需证明 J \mathbf{J} J 的任意一行的和为 0 0 0 即可。
∑ j = 1 n J i j = ∑ j ≠ i ( − σ i σ j ) + σ i ( 1 − σ i ) = − ∑ j ≠ i σ i σ j − σ i σ i + σ i = − σ i + σ i = 0 \sum_{j=1}^n \mathbf{J}_{ij}=\sum_{j\neq i} (-\sigma_i\sigma_j) + \sigma_i(1-\sigma_i)= -\sum_{j\neq i} \sigma_i\sigma_j-\sigma_i\sigma_i+\sigma_i=-\sigma_i+\sigma_i=0 j=1∑nJij=j=i∑(−σiσj)+σi(1−σi)=−j=i∑σiσj−σiσi+σi=−σi+σi=0
由以上结果可知 J 1 = 0 \mathbf{J}\mathbf{1}=\mathbf{0} J1=0,这说明 1 \mathbf{1} 1 是 J \mathbf{J} J 的一个特征向量,对应的特征值为 0 0 0。
性质3: J \mathbf{J} J 为半正定矩阵且 rank ( J ) = n − 1 \text{rank}(\mathbf{J})=n-1 rank(J)=n−1。
证明:对于任意向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n x∈Rn,考虑二次型:
x T J x = x T ( diag ( σ ) − σ σ T ) x = ∑ i = 1 n σ i x i 2 − ( ∑ i = 1 n σ i x i ) 2 \mathbf{x}^\mathrm{T} \mathbf{J} \mathbf{x} = \mathbf{x}^\mathrm{T} \left( \text{diag}(\sigma) - \sigma \sigma^\mathrm{T} \right) \mathbf{x} = \sum_{i=1}^n \sigma_i x_i^2 - \left( \sum_{i=1}^n \sigma_i x_i \right)^2 xTJx=xT(diag(σ)−σσT)x=i=1∑nσixi2−(i=1∑nσixi)2
将 σ i \sigma_i σi 视为概率分布,定义随机变量 X X X,其取值为 x i x_i xi,概率为 σ i \sigma_i σi。则:
x T J x = E [ X 2 ] − ( E [ X ] ) 2 = Var ( X ) ≥ 0 \mathbf{x}^\mathrm{T} \mathbf{J} \mathbf{x} = \mathbb{E}[X^2] - (\mathbb{E}[X])^2 = \operatorname{Var}(X) \geq 0 xTJx=E[X2]−(E[X])2=Var(X)≥0
因此 J \mathbf{J} J 是半正定矩阵。
当 Var ( X ) > 0 \operatorname{Var}(X) > 0 Var(X)>0 时,意味着随机变量 X X X 至少有两个不同的取值,也就是说 x \mathbf{x} x 不是一个常数向量,且 ∥ x ∥ 2 > 0 \Vert \mathbf{x}\Vert^2>0 ∥x∥2>0,从而
0 < Var ( X ) = x T J x = λ ∥ x ∥ 2 ⇒ λ > 0 0<\operatorname{Var}(X)=\mathbf{x}^\mathrm{T} \mathbf{J} \mathbf{x}=\lambda \Vert \mathbf{x}\Vert^2 \quad \Rightarrow\quad \lambda >0 0<Var(X)=xTJx=λ∥x∥2⇒λ>0
结合性质2可知 J \mathbf{J} J 只有一个零特征值,剩余的 n − 1 n-1 n−1 个特征值均为正数,从而 rank ( J ) = n − 1 \text{rank}(\mathbf{J})=n-1 rank(J)=n−1。
4. 交叉熵损失的梯度计算
设Softmax的输入为 z \mathbf{z} z,输出为 σ \mathbf{\sigma} σ,标签为 y \mathbf{y} y(one-hot向量),那么交叉熵损失定义为:
L = − ∑ i = 1 n y i ⋅ log σ i = − y T log σ \mathcal{L}=-\sum_{i=1}^n y_i\cdot \log \sigma_i=-\mathbf{y}^{\text{T}}\log \sigma L=−i=1∑nyi⋅logσi=−yTlogσ
于是
∂ L ∂ z = ( ∂ σ ∂ z ) T ⋅ ∂ L ∂ σ = ( diag ( σ ) − σ σ T ) ⋅ ( − y σ ) = σ − y \begin{aligned} \frac{\partial \mathcal{L}}{\partial \mathbf{z}}=\left( \frac{\partial \mathcal{\sigma}}{\partial \mathbf{z}}\right)^{\text{T}}\cdot \frac{\partial \mathcal{L}}{\partial \mathbf{\sigma}}= (\text{diag}(\sigma) - \sigma \sigma^\mathrm{T})\cdot \left(-\frac{\mathbf{y}}{\sigma}\right)=\sigma-\mathbf{y} \end{aligned} ∂z∂L=(∂z∂σ)T⋅∂σ∂L=(diag(σ)−σσT)⋅(−σy)=σ−y
这说明损失函数对输入 z \mathbf{z} z 的梯度可以简洁地表示为Softmax的输出与真实标签之间的差值。
5. Softmax的各种变体
在这一节中,我们会使用C++来实现Softmax的各种变体。
5.1 Naive Softmax
即最原始的softmax。总共需要两次遍历,第一次遍历是用来计算分母,第二次遍历是用来计算每一个输出。
std::vector<double> softmax(const std::vector<double>& input) {std::vector<double> output(input.size());double sum = 0.0;for (size_t i = 0; i < input.size(); i++) {output[i] = std::exp(input[i]);sum += output[i];}for (size_t i = 0; i < output.size(); i++) {output[i] /= sum;}return output;
}
5.2 Safe Softmax
在实际计算中, e x e^x ex 很容易发生上溢问题。例如,在 Python 中,当 x ≥ 710 x \geq 710 x≥710 时,计算 e x e^x ex 就会导致数值溢出。此外,在深度学习模型中,fp16 精度广泛用于加速训练,但它的最大表示范围仅为 65536 65536 65536。这意味着,当 x ≥ 11 x \geq 11 x≥11 时, e x e^x ex 的值将超过 fp16 能够表示的数值范围,进而导致上溢错误。
鉴于Softmax的平移不变性以及 e x e^x ex 对于负输入的计算会更加精确,我们可以通过将所有输入减去其最大值来缓解上溢问题,即计算 Softmax ( z − max ( z ) ) \text{Softmax}(\mathbf{z} - \max(\mathbf{z})) Softmax(z−max(z))。这一操作不仅可以避免数值溢出,还能提升计算的稳定性和精度,并且不影响Softmax的输出结果。
大部分的深度学习框架都采用了Safe Softmax,但它的缺点是需要进行三次遍历,第一次用来统计最大值,第二次用来计算分母,第三次用来计算每个输出。
std::vector<double> safe_softmax(const std::vector<double>& input) std::vector<double> safe_softmax(const std::vector<double>& input) {std::vector<double> output(input.size());double max_val = input[0];for (size_t i = 1; i < input.size(); i++) {max_val = std::max(max_val, input[i]);}double sum = 0.0;for (size_t i = 0; i < input.size(); i++) {output[i] = std::exp(input[i] - max_val);sum += output[i];}for (size_t i = 0; i < output.size(); i++) {output[i] /= sum;}return output;
}
5.3 Online Softmax
Safe Softmax需要三次遍历,那有没有可能将其缩减至两次遍历呢?
定义 m k = max i = 1 k z i m_k=\max_{i=1}^k z_i mk=maxi=1kzi,那么有 m n = max ( z ) m_n=\max(\mathbf{z}) mn=max(z)。定义 d k = ∑ i = 1 k e z i − m n d_k=\sum_{i=1}^k e^{z_i-m_n} dk=∑i=1kezi−mn,那么 d n d_n dn 就是Safe Softmax中的分母。
回顾Safe Softmax的计算过程,为了得到 d n d_n dn,我们需要从 d 1 d_1 d1 计算到 d n d_n dn,但无论是哪一个 d k d_k dk,都会涉及到 m n m_n mn,这意味着在此之前必须先通过一轮循环得到 m n m_n mn。
能否把计算 m n m_n mn 的这一轮循环砍掉呢?一个直观的想法就是在计算 d 1 d_1 d1 到 d n d_n dn 的过程中同时去计算 m 1 m_1 m1 到 m n m_n mn。但此时计算 d k d_k dk 时,我们只有 m k m_k mk,并没有 m n m_n mn,该怎么办呢?
先将计就计试一下,令 d k ′ = ∑ i = 1 k e z i − m k d'_k=\sum_{i=1}^ke^{z_i-m_k} dk′=∑i=1kezi−mk,容易发现
d n ′ = ∑ i = 1 n e z i − m n = d n d'_n=\sum_{i=1}^ne^{z_i-m_n}=d_n dn′=i=1∑nezi−mn=dn
这说明我们可以用 d n ′ d'_n dn′ 作为Safe Softmax的分母。接下来只需要关注如何去计算 d k ′ d'_k dk′。注意到
d k ′ = ∑ i = 1 k e z i − m k = ∑ i = 1 k − 1 e z i − m k + e z k − m k = ( ∑ i = 1 k − 1 e z i − m k − 1 ) ⋅ e m k − 1 − m k + e z k − m k = d k − 1 ′ ⋅ e m k − 1 − m k + e z k − m k \begin{aligned} d'_k&=\sum_{i=1}^ke^{z_i-m_k} \\ &=\sum_{i=1}^{k-1} e^{z_i-m_k} +e^{z_k-m_k} \\ &=\left(\sum_{i=1}^{k-1} e^{z_i-m_{k-1}}\right)\cdot e^{m_{k-1}-m_k}+e^{z_k-m_k} \\ &=d'_{k-1}\cdot e^{m_{k-1}-m_k}+e^{z_k-m_k} \end{aligned} dk′=i=1∑kezi−mk=i=1∑k−1ezi−mk+ezk−mk=(i=1∑k−1ezi−mk−1)⋅emk−1−mk+ezk−mk=dk−1′⋅emk−1−mk+ezk−mk
且 d 1 ′ = e z 1 − m 1 d'_1=e^{z_1-m_1} d1′=ez1−m1,于是我们可以用这个递推式来计算得到 d n ′ d'_n dn′。
std::vector<double> online_softmax(const std::vector<double>& input) {std::vector<double> output(input.size());double pre_max_val = input[0];double sum = std::exp(input[0] - pre_max_val);for (size_t i = 1; i < input.size(); i++) {double cur_max_val = std::max(pre_max_val, input[i]);sum = sum * std::exp(pre_max_val - cur_max_val) + std::exp(input[i] - cur_max_val);pre_max_val = cur_max_val;}for (size_t i = 0; i < input.size(); i++) {output[i] = std::exp(input[i] - pre_max_val) / sum;}return output;
}
Ref
[1] https://zhuanlan.zhihu.com/p/638788074
[2] https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
[3] https://arxiv.org/pdf/1805.02867v2