gMLP模型代码讲解
- Introduction
- gMLP网络结构
- Spatial Gating Unit (SGU)
- code
- gMLPBlock
- Spatial Gating Unit
基于MLP-Mixer 的改进…
Introduction
总的来说,gMLP 在视觉和NLP领域的惊人有效性表明,自我注意并不是扩大机器学习模型的必要因素,尽管它根据任务的不同可以是一个有用的补充随着数据和计算量的增加。具有gMLP等更简单的空间交互机制的模型可以像变压器一样强大,分配给自我注意的能力可以被删除或大幅减少。
gMLP网络结构
gMLP 的输入仍为若干图像块(即将一张图像切割成若干图像块),输出为若干个向量(token)堆叠组成的矩阵,例如token的维度为L,个数为N,则输出为N ∗ L 的矩阵,通过池化等操作转换为最终的特征向量。
由若干个基本构成单元堆叠而成
设输入矩阵(即图中的Input Embeddedings)为 n ∗ d n∗d n∗d 的矩阵X , n为序列长度, d为特征维度,则gMLP的unit结构可以简化为 Z = δ ( X U ) Z ~ = s ( Z ) Y = δ ( Z ~ V ) + X Z=\delta (XU)\\ \tilde{Z} = s(Z)\\ Y=\delta(\tilde{Z}V)+X Z=δ(XU)Z~=s(Z)Y=δ(Z~V)+X
U , V U,V U,V为可学习的矩阵, δ \delta δ 为激活函数, s ( z ) s(z) s(z) 为图中的Spatial Gating Unit.
Spatial Gating Unit (SGU)
为了能有跨token的交互, s ( ⋅ ) s(\cdot) s(⋅) 操作须在空间维度。可以简单的使用线性映射表示: f W , b ( Z ) = W Z + b s ( Z ) = Z ⊙ f W , b ( Z ) f_{W,b}(Z)=WZ+b\\ s(Z)=Z⊙f_{W,b}(Z) fW,b(Z)=WZ+bs(Z)=Z⊙fW,b(Z) 设 Z Z Z 为 n ∗ d n∗d n∗d 的矩阵,则 W W W 为 n ∗ n n∗n n∗n 的矩阵,表示空间交互的映射参数,b 为n 维向量(WZ+b表示WZ的第一行元素与b的第一维元素相加),为了保证训练的稳定性,W 初始化值接近于0(貌似用[-1,1]的均匀分布初始化),b 的初始值为1,此时 f W , b ( Z ) ≈ 1 , s ( Z ) ≈ Z f_{W,b}(Z)\approx1,s(Z)\approx Z fW,b(Z)≈1,s(Z)≈Z,这种初始化确保了每个gMLP块在训练的早期阶段像一个常规的FFN,其中每个token 都被独立处理,并且只在学习过程中逐步跨token注入空间信息。
更进一步的作者发现将Z 沿着channel维度切割成 Z 1 , Z 2 Z_1,Z_2 Z1,Z2 ( Z 1 , Z 2 Z_1,Z_2 Z1,Z2的维度分别为 n ∗ d 1 , n ∗ d 2 , d 1 + d 2 = n n*d_1,n*d_2,d_1+d_2=n n∗d1,n∗d2,d1+d2=n)两个部分更为有效,此时s(Z)操作变为
s ( Z ) = Z 1 ⊙ f W , b ( Z 2 ) s(Z)=Z_1\odot f_{W,b}(Z_2) s(Z)=Z1⊙fW,b(Z2)
code
先看整体结构,在整个gMLP结构中,gmlp代替self-attention设计了框架结构。每一个层级使用gMLPBlock作为一个block阶段。整个残差形式为gmlp(norm(x))+x
.
class gMLP(nn.Module):def __init__(self,*,...):super().__init__()dim_ff = dim * ff_multself.seq_len = seq_lenself.prob_survival = prob_survivalself.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act))) for i in range(depth)])# gmlp(norm(x))+xself.to_logits = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_tokens)) if exists(num_tokens) else nn.Identity()def forward(self, x):x = self.to_embed(x)layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)out = nn.Sequential(*layers)(x)return self.to_logits(out)
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.LayerNorm(dim)def forward(self, x, **kwargs):x = self.norm(x)return self.fn(x, **kwargs)
gMLPBlock
class gMLPBlock(nn.Module):def __init__(self,*,dim,dim_ff,seq_len,attn_dim = None,causal = False,act = nn.Identity()):super().__init__()self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff),nn.GELU())# dim_ff = dim * ff_mult(4)# dim -> dim*4self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else Noneself.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act)self.proj_out = nn.Linear(dim_ff // 2, dim)def forward(self, x):gate_res = self.attn(x) if exists(self.attn) else None# 默认的attn是None,即不进行该操作x = self.proj_in(x)x = self.sgu(x, gate_res = gate_res)x = self.proj_out(x)return x
Spatial Gating Unit
class SpatialGatingUnit(nn.Module):def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):super().__init__()dim_out = dim // 2self.causal = causalself.norm = nn.LayerNorm(dim_out)self.proj = nn.Conv1d(dim_seq, dim_seq, 1)self.act = actinit_eps /= dim_seqnn.init.uniform_(self.proj.weight, -init_eps, init_eps)nn.init.constant_(self.proj.bias, 1.)def forward(self, x, gate_res = None):device, n = x.device, x.shape[1]res, gate = x.chunk(2, dim = -1)# self-atten 用的dim# sgu用的dim_ff = dim * ff_mult(4),即4倍# chunk之后,每个为2倍,用两倍的值进行attentiongate = self.norm(gate)weight, bias = self.proj.weight, self.proj.biasif self.causal:...gate = F.conv1d(gate, weight, bias)# 1d卷积混合w*h维度的信息,patch通道的混合if exists(gate_res):gate = gate + gate_resreturn self.act(gate) * res