Swin Transformer解读
论文题目:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
官方代码地址:https://github.com/microsoft/Swin-Transformer.
引言与概括
-
ICCV2021的最佳论文作者是来自微软亚洲研究院。
-
SwinTransformer名字的由来:ShiftedWindows,swin名字的由来S+win。
-
标题中提到的两个重点 Hierarchical:层级,为的是可以提取出多尺度的图像特征 Shifted Window:移动窗口。
摘要概括:
本文提出了一种新的vision Transformer,称为Swin Transformer,它能够作为计算机视觉的通用骨干网络。
从语言到视觉的挑战来自于这两个领域之间的差异,比如视觉实体规模的变化很大,以及图像中像素比文本中单词的高分辨率。为了解决这些差异,我们提出了一个分层Transformer,其表示是由Shifted windows计算的。Shifted windows方案将自注意计算限制在非重叠的局部窗口上,同时允许跨窗口连接,从而提高了更高的效率。
核心总结: 将之前在Vision Transform中的那个Vision Transform Block块 两个连接在一起,将第一个块中的多头注意力机制模块替换为了W-MSA 第二个块中的多头注意力机制模块替换为了SW-MSA
论文中提到了之前的Vision Transform中存在的一个问题(或者说是挑战)
输入的尺寸是W x H 而输出的尺寸依然是W x H 没有对其进行下采样的处理,从而缺乏多尺度的检测能力。
- MSA : multi-head self-attention
- W-MSA : windows - multi-head self-attention
- SW-MSA:shifted windows - multi-head self-attention
从而解决了Transform中的多尺度检测的问题。
整体网络结构(swin-tiny版本)
我们使用的是SWin -T的版本来进行讲解和说明的。
在代码实现的时候更多是参考下面的层来进行实现的。(执行分类任务加了一个用来分类的头)
It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT. Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB。values. In our implementation, we usea patch size of 4×4 and thus the feature dimension of each patch is 4 × 4 × 3 = 48. A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension
和vision Transform 一样采样的是 Embedding层,对输入的图片进行处理,从而将输入的图像转变为token。在通过Linear Embedding将通道数扩展为原来的2倍。(其他的模型将通道的数目扩展到C)
从而将图像的高宽下采样了4倍,得到了对应的图中的第一个位置的w/4 H/4和 48的结果信息。
这两步在代码上面是使用的一个4x4的卷积层来之间进行实现的得到的是56x56x96的结果,最后将56x56进行一个展平的操作步骤。
对于不同的任务来说,在结构的最后需要使用到不同的形式来进行,例如对于分类任务在加入一个全连接层,和分类头。检测或者其他的任务可以进行进一步的类比。
之后的Patch Merging就相当于是一个进行下采样的一个操作。
Patch Merging部分
前面有说,在每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
Windows Multi-head Self-Attention(W-MSA)
引入Windows Multi-head Self-Attention(W-MSA
)模块是为了减少计算量
。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素
(或称作token
,patch
)在Self-Attention计算过程中需要和所有的像素
去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM
大小划分成一个个Windows
,然后单独对每个Windows内部
进行Self-Attention。
- 方式:在每个window中独立的计算MSA
- 优势:计算量较小、
- 弊端:Windows之间没有信息交互
对于节省了多少计算量论文中给出了相关的公式来进行进一步的描述。
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \begin{array}{l} \Omega(\mathrm{MSA})=4 h w C^{2}+2(h w)^{2} C \\ \Omega(\mathrm{W}-\mathrm{MSA})=4 h w C^{2}+2 M^{2} h w C \end{array} Ω(MSA)=4hwC2+2(hw)2CΩ(W−MSA)=4hwC2+2M2hwC
- h代表featuremap的高度
- w代表featuremap的宽度
- C代表featuremap的深度
- M代表每个窗口(Windows)的大小
MSA的计算量分析
首先引入矩阵乘法的计算量公式:
A a × b ⋅ B b × c FLOT: a × b × c \begin{array}{l} A^{a \times b} \cdot B^{b \times c}\\ \text { FLOT: } a \times b \times c \end{array} Aa×b⋅Bb×c FLOT: a×b×c
自注意力机制计算量
对于feature map中的每个像素(或称作token,patch),都要通过 W q , W k , W v \mathrm{W}_{\mathrm{q}}, \mathrm{W}_{\mathrm{k}}, \mathrm{W}_{\mathrm{v}} Wq,Wk,Wv
生成对应的query(q),key(k)以及value(v)。这里假设q, k, v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:
A h w × C ⋅ W q C × C = Q h w × C \mathrm{A}^{\mathrm{hw} \times \mathrm{C}} \cdot \mathrm{W}_{\mathrm{q}}^{\mathrm{C} \times \mathrm{C}}=\mathrm{Q}^{\mathrm{hw} \times \mathrm{C}} Ahw×C⋅WqC×C=Qhw×C
A h w × C 为将所有像素 (token) 拼接在一起得到的矩阵 (一共有hw个像素,每个像素的深度为C) A^{\mathrm{hw}} \times \mathrm{C} \text { 为将所有像素 (token) 拼接在一起得到的矩阵 (一共有hw个像素,每个像素的深度为C) } Ahw×C 为将所有像素 (token) 拼接在一起得到的矩阵 (一共有hw个像素,每个像素的深度为C)
W q C × C 为生成query的变换矩阵 \mathrm{W}_{\mathrm{q}}^{\mathrm{C} \times \mathrm{C}} \text { 为生成query的变换矩阵 } WqC×C 为生成query的变换矩阵
Q h w × C 为所有像素通过 W q C × C 得到的query拼接后的矩阵 \mathrm{Q}^{\mathrm{hw} \times \mathrm{C}} \text { 为所有像素通过 } \mathrm{W}_{\mathrm{q}}^{\mathrm{C} \times \mathrm{C}} \text { 得到的query拼接后的矩阵 } Qhw×C 为所有像素通过 WqC×C 得到的query拼接后的矩阵
根据矩阵运算的计算量公式可以得到生成Q的计算量为hw×C×C,生成K和V同理都是hwC的2,那么总共是3hwC的2。接下来Q和KT相乘,对应计算量为(hw)的2C
Q h w × C ⋅ K T ( C × h w ) = X h w × h w \mathrm{Q}^{\mathrm{hw} \times \mathrm{C}} \cdot \mathrm{K}^{\mathrm{T}(\mathrm{C} \times \mathrm{hw})}=\mathrm{X}^{\mathrm{hw} \times \mathrm{hw}} Qhw×C⋅KT(C×hw)=Xhw×hw
假设得到 Λ h w × h w ,最后还要乘以V,对应的计算量为 ( h w ) 2 C : \text { 假设得到 } \Lambda^{\mathrm{hw} \times h w} \text { ,最后还要乘以V,对应的计算量为 }(\mathrm{hw})^{2} \mathrm{C} \text { : } 假设得到 Λhw×hw ,最后还要乘以V,对应的计算量为 (hw)2C :
Λ h w × h w ⋅ V h w × C = B h w × C \Lambda^{\mathrm{hw} \times \mathrm{hw}} \cdot \mathrm{V}^{\mathrm{hw} \times \mathrm{C}}=\mathrm{B}^{\mathrm{hw} \times \mathrm{C}} Λhw×hw⋅Vhw×C=Bhw×C
将之前的计算量进行一个求和的操作就可以得到单头注意力机制的总体计算总量
3 h w C 2 + ( h w ) 2 C + ( h w ) 2 C = 3 h w C 2 + 2 ( h w ) 2 C 3 \mathrm{hwC}^{2}+(\mathrm{hw})^{2} \mathrm{C}+(\mathrm{hw})^{2} \mathrm{C}=3 \mathrm{hwC}{ }^{2}+2(\mathrm{hw})^{2} \mathrm{C} 3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C
多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵wo计算量为hwc2
4 h w C 2 + 2 ( h w ) 2 C 4 \mathrm{hwC}^{2}+2(\mathrm{hw})^{2} \mathrm{C} 4hwC2+2(hw)2C
W-MSA的计算量
对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到 h M × w M \frac{\mathrm{h}}{\mathrm{M}} \times \frac{\mathrm{w}}{\mathrm{M}} Mh×Mw
个窗口。
然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为。
4 h w C 2 + 2 ( h w ) 2 C 4 \mathrm{hwC}^{2}+2(\mathrm{hw})^{2} \mathrm{C} 4hwC2+2(hw)2C
这里每个窗口的高为M宽为M,带入公式得
4 ( M C ) 2 + 2 ( M ) 4 C 4(\mathrm{MC})^{2}+2(\mathrm{M})^{4} \mathrm{C} 4(MC)2+2(M)4C
结合最终得到的窗口的数目就可以得到最后的计算量了
h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w 2 + 2 M 2 h w C \frac{\mathrm{h}}{\mathrm{M}} \times \frac{\mathrm{w}}{\mathrm{M}} \times\left(4(\mathrm{MC})^{2}+2(\mathrm{M})^{4} \mathrm{C}\right)=4 \mathrm{hw}^{2}+2 \mathrm{M}^{2} \mathrm{hwC} Mh×Mw×(4(MC)2+2(M)4C)=4hw2+2M2hwC
Shifted Multi-head Self-Attention Window (SW-MSA)
目的:实现不同 Window之间的信息交互。
直观的理解就是在原来window的基础上分别向右和向下移动了两个单位的 将windows向右下角移动windowsize//2的位置大小。
根据信息的位置就可以看出下一层的窗口在计算的过程中就融合了之前各个窗口的信息,从而实现了交互的功能。
经过分离之后就可以偏离出9个窗口了。如下所示。
如何通过偏移窗口来进行计算的
参考上面的示意图:将0号区域标记为A,将1号和2号区域标记为B,将3号和六号区域标记为C
- 将A和C两个区域先移动到下方来完成第一步的操作。
2. 在将A和B的区域移动到最右边去完成第二步的一个操作。
- 将对应的区域坐标进行一个合并最后可以得到4x4的 4个划分的区域信息位置。
即分别为:
- 4
- 5 3
- 7 1
- 8 6 2 0
要单独的计算区域5的MSA和区域3的MSA部分,论文当中的图4中结合给出了一部分的说明,就是使用一个masked—MSA来完成。
移动完后,4是一个单独的窗口;将5和3合并成一个窗口;7和1合并成一个窗口;8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口了,所以能够保证计算量是一样的。这里肯定有人会想,把不同的区域合并在一起(比如5和3)进行MSA,这信息不就乱窜了吗?是的,为了防止这个问题,在实际计算中使用的是masked MSA即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。
masked掩膜操作就相当于是一个减法的操作步骤。使得3和5的两个区域保持独立的计算方式。
注意:在全部计算完成之后需要将数据在挪回到原来的位置上的,而不是使用应该移动之后的位置来进行一个计算的操作。向左和向上进行一个还原的操作步骤。
相对位置偏移(Relative bias position)
Attention ( Q , K , V ) = SoftMax ( Q K T / d + B ) V \operatorname{Attention}(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V Attention(Q,K,V)=SoftMax(QKT/d+B)V
论文中的表4给出了使用相对位置偏移之后,可以提升一定的性能:
计算步骤
- 计算相对的位置
得到M x M大小的相对位置索引的矩阵信息。 有2m-1的平方中组合的方式。(因为取值的范围是-M+1到M-1的取值范围)
- 对相对的位置进行编码处理。
根据上一步生成的相对位置索引,得到相对位置偏移矩阵B最后根据相对位置偏移矩阵B结合之前提到过的位置索引矩阵得到最后的相对位置偏移带入公式中进行计算