知乎:刀刀宁
链接:https://zhuanlan.zhihu.com/p/718156896
线性注意力机制的文章有很多了,在本篇笔记中,我们简单地对各种方法进行一下图解比较,串一下当前的线性注意力机制,涉及的公式极少,主要梳理逻辑脉络。本文会从 state space model 中间状态模型这条主线,来梳理 RNN、LSTM,再到 Retentive、GLA 等 Linear Attention 的改进版,最后再到 Mamba、Mamba-2、RWKV 等方法。
线性注意力机制的好处很多,可以用“多快好省”来形容:理论复杂度低、速度快、结构简单、上下文长度线性依赖、KVCache 不需要额外存储,且优化容易。但相比 full attention,线性注意力机制的表达能力确实差一截,且无法完全丢弃历史信息,类似于 RNN 的遗忘和依赖关系,因此产生了各种改进方法。
同时,线性注意力也具备很多并行和 IO 感知的优化,否则复杂度线性化后,并行和运算速度若不如 full attention,就显得鸡肋。因此,如何结合硬件(主要是 CUDA GPU 的特点)来进行注意力机制的系统级优化是不可忽略的问题。
Part 1: Linear Attention 与非必要 softmax
Linear Attention Transformers (Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention) 的论文发表于 2020 年。
https://proceedings.mlr.press/v119/katharopoulos20a/katharopoulos20a.pdf
为了帮助理解,先引用苏神 2020 的文章:《Attention 必须有个 Softmax 吗?》以及《超无聊的:Gated Linear Attention Transformers with Hardware-Efficient Training》,来看一下去掉 softmax 函数后的 attention 机制。这里省略公式和证明,感兴趣的读者可移步前文。
https://spaces.ac.cn/archives/7546 https://zhuanlan.zhihu.com/p/672824235
下图左图是原来的 attention 机制,矩阵乘法的顺序和计算复杂度:设序列长度为 ,当前复杂度为 级,这是我们熟悉的情形。而右图则去掉了 softmax,用近似函数 sim
替代,并改变了 QKV 的计算顺序(本文中的典型线性注意力机制)。这时,神奇的事情发生了:中间结果从 的矩阵变成了 ,复杂度变成了 线性(当然,若 是 4096 级别, 也很大,此时还需考虑减小 等方法)。但整个运算过程与 的长度呈线性相关性。
这只是 softmax 的原因吗?最初认为是,但深入研究后发现,这与 softmax 关系不大。继续往后看。
自回归阶段的矩阵计算
上图展示的是 prefill 阶段(或整体大矩阵相乘阶段)。现在看一下自回归阶段或生成阶段(逐个 token 输出时)的矩阵计算:
下图左图是 full attention,而右图为线性注意力机制。每次解码时,计算出的 中间结果矩阵可以直接叠加到历史中间矩阵中。这个中间矩阵可以称为 State Space Model(SSM),是中间状态模型。每次新的 SSM 可以与之前的所有 SSM 直接相加,这与 prefill 阶段的大矩阵乘法在数学上是等价的。比较 full attention 和 linear attention 在不同长度下的区别:随着 增加,full attention 中 attention 第二步的矩阵长度也增加;而 linear attention 中的 SSM 大小一直保持 。
Part 2: Linear Attention 的 state space 和 full attention 的本质区别
进一步看,full attention 保留了每个 query 与历史上每个生成 token 之间的关系;linear attention 则通过更新 SSM,将所有信息保留在 SSM 中。SSM 大小不变,叠加进去的信息由于加法操作失去了具体的 query 指向,运行时无法单独抽取特定信息。
同时,线性注意力无法强调或丢弃特定信息,无法像 softmax 那样突出重点或遗忘非重点,如同一锅粥。softmax 不仅强调了当次重点,抛弃了非重点,还通过 级运算保留了所有 token 间的相对关系,使信息关系得到完整保留。SSM 则忽略了步数索引。
Part 3: RNN、LSTM 与 cell state、state space model
这时,我们自然联想到 RNN 的 hidden state,即这里的 SSM 的中间状态矩阵。因此,Linear Attention Transformers 的论文标题为 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention。这种情形类似于 RNN 遇到的问题,即无法有效依赖之前的特定节点,所以有了 LSTM。LSTM 是一种长短期记忆网络,通过一系列门控和组合机制,使 hidden state space 中的状态根据输入捕捉与历史信息关联紧密的部分。
https://xiaosheng.blog/2017/09/16/what-is-lstm
LSTM 的关键在于 cell state,即贯穿顶部的水平线,类似输送带,持续流动,信息保持稳定。这实际上就是状态空间(state space)。
Part 4: Linear Attention Recurrent Representation
在 RNN 和 LSTM 中的 hidden state 的基础上,我们现在可以套用 Linear Attention 进行表示。通过对比 RNN 和 LSTM,发现它们的状态线是一致的,而输入由直接输入变成了 QKV 三类键值。与 full attention 相比,Linear Attention 先对 KV 和之前的 时刻的 SSM 进行叠加得到 时刻的 SSM,再与 Q 相乘得到输出。
Part 5: Linear Attention 的变种:Retention 与 GLA
通过图解可以发现,技术线上的一些热门方法都可以应用这种逻辑。
Retentive Network
例如:Retentive Network: A Successor to Transformer for Large Language Models。
https://arxiv.org/abs/2307.08621
本文使用的图的雏形就源于这篇论文。核心公式并不复杂,只是增加了一个用于控制对之前 SSM 状态的加权机制。回到前面的图片逻辑表达方式,可以将其简化为下图所示:这种简洁的结构具有一定的问题:它类似于循环神经网络(RNN)和长短期记忆网络(LSTM),需要更复杂的带有遗忘门(forget gate)的注意力机制/状态空间结构。这里的遗忘门是输入数据依赖的,而非如保留机制(retention)中那样人为固定。
https://arxiv.org/abs/2312.06635
GLA 模型与 LSTM/GRU 类似,但其门控不依赖于上一层递归的状态,并具有非线性结构。GLA 的 G 矩阵类似于 QKV 结构。门控计算完成后,整个模型依然属于线性循环神经网络(RNN)的范畴。
注意:⊙ 符号代表 element-wise 的计算,称为 Hadamard 乘法,而非矩阵乘法。通过将公式与 recurrent 表示形式相结合,发现 GLA 的结构变得复杂了不少,输入输出各增加了一个新的门控机制。GLA 可以算是当前较好的 Linear Attention 的改进版。
Part 6: Mamba
Mamba 结构其实与 Linear Attention 关系不大,更类似 RNN 系列。我们在此进行类比以提升对比学习效率。
参考https://blog.csdn.net/v_JULY_v/article/details/134923301 https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
Mamba 的 S4 结构公式并不复杂,可以用以下 recurrent 结构进行表示(虽然不非常精确)。其中的 ABCD 四个矩阵:A 是一个结构化的 HiPPO 矩阵,作用是对 hidden state 进行加权。
S4 的问题在于线性不变性(Linear Time Invariance),SSM 依赖的 A、B 和 C 不会随不同的输入变化。为解决这个问题,S6 引入了 selective 和 scan 操作:
Selective:为每个维度或通道维护一个独立的 SSM,类似于 Multi-Head Attention 中不同的注意力头。
Scan:为支持在 selective 机制下的 parallel 处理进行加速。有了 selective 和 scan 机制,Mamba 模型可以更有效地关注必须关注的部分,过滤掉可忽略的信息。
Part 7: Mamba-2
Mamba 在学术界很火,但实测效果不佳。如今又有了 Mamba-2,提出了一些新的概念,如 SSM 衍生的 SSA 和 SSD。
https://arxiv.org/abs/2405.21060 https://blog.csdn.net/v_JULY_v/article/details/140131413 https://www.jiqizhixin.com/articles/2024-06-04-7 https://goombalab.github.io/blog/2024/mamba2-part1-model/
SSA:state space Attention
SSD:state space duality
SSD 是个 dual model 的双模逻辑,有点像比亚迪的 DMi,也像掼蛋:大规则是升级,小规则是斗地主。Mamba-2 的 SSD 设计初衷是 full Attention 结构,但通过各种 mask 和分块模式,将整体转换为一个个小分块,每个分块内部是一个“线性”的 SSM 模式。SSD 就是上图中的重叠部分,兼具 SSM 和 SMA 的特点。
然而,Mamba-2 过于复杂,硬要将 full attention 的优势强加到 Mamba 这种纯 RNN 模型上。业界反应也较为冷淡,所以 Mamba-2 的效果如何,尚需观察。
Part 8: RWKV
最后一道菜是 RWKV。
RWKV 的文档很多,可以参考《RWKV解读:在Transformer的时代的新RNN》,苏建林的《如何评价最新的RWKV论文 (arXiv 2305.13048)》,以及 PENG Bo 的《RWKV:用 RNN 达到 Transformer 性能,且支持并行模式和长程记忆,既快又省显存,已在14B参数规模检验》。
https://zhuanlan.zhihu.com/p/656323242 https://www.zhihu.com/question/602564718/answer/3062973388 https://zhuanlan.zhihu.com/p/599150009
从结构图中可以看出,RWKV 的 time mixing 模块是重点,它通过 SSM 的 hidden state 将 Transformer 中的 Attention 替换为 RWKV 独有的 recurrent 方式。因此,RWKV 是一个线性 RNN 模型。虽然表面上与 GLA 不同,但底层逻辑上是否相似或存在本质区别,尚需进一步研究。
总结
这篇笔记酝酿了很久,要学习和思考的内容不少。通过 state space model 的 recurrent 表达形式,我对前面各种方法的大体思路进行了梳理和总结。由于篇幅限制,很多细节没有展开,可能存在一些错误,欢迎批评指正。我会慢慢更新这篇笔记。
本文未分析各个方法的运算友好程度和优化方式,也未涉及这些方法的训练方法,包括常规训练和长文本训练外推的逻辑等。
备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群
id:DLNLPer,记得备注呦