1. Transformer为什么需要位置编码
因为 transformer 结构本身是和位置编码无关的: Y = T ( X ) = F ( A ( X ) ) Y=\Tau(X)=F(A(X)) Y=T(X)=F(A(X)),其中 A ( ) A() A() 是 attention 变换,只进行了矩阵变换,跟位置无关, F ( ) F() F() 是前馈网络变换,也跟位置无关。
高阶语义向量不仅仅是由周围 token 的语义向量组合表达而成,还需要加上每个 token 所处的位置。比如我们在句子前面加上
< c l s > <cls> <cls> 来学习整个句子的语义向量,如果没有位置编码的话下面两个句子 < c l s > <cls> <cls> token 得到的语义向量是完全一样的。
像是 RNN 和 CNN 结构天然就有序列的位置信息。
2. 如何加入位置编码
2.1. 直接修改输入
在 X i → Q i , K i , V i X_i \to Q_i,K_i,V_i Xi→Qi,Ki,Vi 之前直接加入位置的embedding: X i ′ = X i + P i X'_i = X_i+P_i Xi′=Xi+Pi。优点是简单,缺点是外推困难,对于超过序列最大长度的位置无法表示。用在BERT和GPT中。
典型的方式比如自定义绝对位置编码:
2.2. 修改Attention
从直觉上来说语义信息关注的是不同 token 之间的相对位置信息,上面的绝对位置编码不够直接,所以引入相对位置编码。
通过 attention 过程引入位置信息说白了就是在计算 token 之间的相关性矩阵时通过引入某种带有位置信息的操作使得算出来的相关性矩阵里带有相对位置信息。
下面从数学的角度和工程的角度来讲解,可以认为是先通过数学推出了用什么函数修改 attention,然后从工程的角度证明这个函数是有效的。
2.2.1 从数学的角度推理
假设我们引入函数 f ( ) f() f(),它有两个参数,一个是 token 的 embedding 信息,一个是这个 token 的位置信息。比如 f ( q , m ) f(q, m) f(q,m),其中 q q q 是 某个token 的 embedding, m m m 是它的位置信息,但是此时 f ( q , m ) f(q, m) f(q,m) 里面是绝对位置信息。Attention 的核心运算是内积,所以我们希望的内积的结果带有相对位置信息,因此假设存在恒等关系(其中 q q q 和 k k k 是两个不同位置的 token 的 embedding, m m m 和 n n n 是对应的位置信息):
这样的话只要能求解出满足上式的函数 f ( ) f() f() 就可以将相对位置信息在 attention 阶段引入了,只要 f ( ) f() f() 确定了, g ( ) g() g() 也就可以确定了。
要求解一个函数,首先我们可以设定一些简单的,符合直觉的初始条件:
我们先考虑二维情形,然后借助复数来求解。在复数中有 < q , k > = R e [ q k ∗ ] <q,k>=Re[qk^*] <q,k>=Re[qk∗], R e [ ] Re[] Re[] 代表复数的实部,所以有:
其中 f ∗ ( k , n ) f^*(k, n) f∗(k,n) 表示 f ( k , n ) f(k, n) f(k,n) 的共轭。
简单起见,我们假设存在复数 g ( q , k , m − n ) g(q, k, m-n) g(q,k,m−n),使得 f ( q , m ) f ∗ ( k , n ) = g ( q , k , m − n ) f(q, m)f^*(k, n) = g(q, k, m-n) f(q,m)f∗(k,n)=g(q,k,m−n),然后设:
注: R f ( q , m ) R_{f(q, m)} Rf(q,m) 是复数 f ( q , m ) f(q, m) f(q,m) 的模, Θ f ( q , m ) \Theta_{f(q,m)} Θf(q,m) 是是复数 f ( q , m ) f(q, m) f(q,m) 的辐角,在苏剑林老师的博客里因为排版问题可能会造成误识。
带入式 (4) 中得到:
详细的推导如下:
f ( q , m ) f ∗ ( k , n ) = R f ( q , m ) e i θ f ( q , m ) R f ( k , n ) e − i θ f ( k , n ) = R f ( q , m ) R f ( k , n ) e i ( θ f ( q , m ) − θ f ( k , n ) ) = g ( q , k , m − n ) = R g ( q , k , m − n ) e i θ g ( q , k , m − n ) f(q, m)f^*(k, n)=R_{f(q, m)}e^{i\theta_{f(q,m)}}R_{f(k, n)}e^{-i\theta_{f(k,n)}}=R_{f(q, m)}R_{f(k, n)}e^{i(\theta_{f(q,m)-\theta_{f(k,n)}})}=g(q, k, m-n)=R_{g(q, k, m-n)}e^{i\theta_{g(q,k,m-n)}} f(q,m)f∗(k,n)=Rf(q,m)eiθf(q,m)Rf(k,n)e−iθf(k,n)=Rf(q,m)Rf(k,n)ei(θf(q,m)−θf(k,n))=g(q,k,m−n)=Rg(q,k,m−n)eiθg(q,k,m−n)
对于第一个方程,代入 m = n m=n m=n 得到:
最后一个等号源于初始条件 f ( q , 0 ) = q f(q,0)=q f(q,0)=q 和 f ( k , 0 ) = k f(k,0)=k f(k,0)=k。我们假设 R f ( q , m ) = ∣ ∣ q ∣ ∣ , R f ( k , m ) = ∣ ∣ k ∣ ∣ R_{f(q, m)}=||q||,R_{f(k,m)}=||k|| Rf(q,m)=∣∣q∣∣,Rf(k,m)=∣∣k∣∣,即不依赖于 m m m(因为我们是要构造出符合假设的函数 f ( ) f() f(),所以无论怎么假设都可以,只要最后能构造出来就行,这样的话假设条件越简单越好)。
公式 (6) 中的第二个等式带入 m = n m=n m=n 得到:
这里的 Θ q , Θ k \Theta_{q},\Theta_{k} Θq,Θk 是 q , k q,k q,k 本身的辐角,最后一个等号同样源于初始条件。根据上式得到 Θ f ( q , m ) − Θ q = Θ f ( k , m ) − Θ k \Theta_{f(q,m)}-\Theta_{q}=\Theta_{f(k,m)}-\Theta_{k} Θf(q,m)−Θq=Θf(k,m)−Θk,所以 Θ f ( q , m ) − Θ q \Theta_{f(q,m)}-\Theta_{q} Θf(q,m)−Θq 应该是一个只与 m m m 相关,跟 q q q 无关的函数,记为 φ ( m ) \varphi(m) φ(m),即 Θ f ( q , m ) = Θ q + φ ( m ) \Theta_{f(q,m)}=\Theta_{q}+\varphi(m) Θf(q,m)=Θq+φ(m)。接着代入 n = m − 1 n=m-1 n=m−1,整理得到:
可以看到等式的自变量为 m m m,而上述等式右边跟自变量 m m m 无关,是一个常数,所以 φ ( m ) \varphi(m) φ(m) 是一个等差数列,我们假设右端为 θ \theta θ,则 φ ( m ) = m θ \varphi(m)=m\theta φ(m)=mθ(这是推导出函数 f ( ) f() f() 的关键一步)
至此,我们已经得到了二维情况下用复数表示的RoPE:
其中 q = ∣ ∣ q ∣ ∣ e i Θ q q=||q||e^{i\Theta_q} q=∣∣q∣∣eiΘq
注:这里要特别注意 Θ \Theta Θ 和 θ \theta θ 的区别,前者是个函数,后者是个数.
根据复数乘法的几何意义,该变换实际上对应着向量的旋转,所以我们称之为“旋转式位置编码”,它还可以写成矩阵形式:
由于内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接,即:
也就是给位置为 m m m 的向量 q q q 乘上矩阵 R m R_m Rm、位置为 n n n 的向量 k k k 乘上矩阵 R n R_n Rn,用变换后的 Q , K Q,K Q,K 序列做 Attention,那么Attention 就自动包含相对位置信息了,因为成立恒等式:
按照 (12) 展开之后用三角函数即可证明。
值得指出的是, R m R_m Rm 是一个正交矩阵,它不会改变向量的模长,因此通常来说它不会改变原模型的稳定性,但是向量的方向会发生变化,这就是旋转位置编码中旋转的由来。
由于 R m R_m Rm 的稀疏性,所以直接用矩阵乘法来实现会很浪费算力,推荐通过下述方式来实现RoPE:
其中⊗是逐位对应相乘,即Numpy、Tensorflow等计算框架中的∗运算。从这个实现也可以看到,RoPE可以视为是乘性位置编码的变体。
2.2.2 从工程的角度
工程上直接用了数学上推出的结论,也就是说 f ( Q , i ) = Q i R ( i ) f(Q,i)=Q_iR(i) f(Q,i)=QiR(i),可以看到上面的数学推理出的式子是 f ( q , m ) = R ( m ) q m f(q,m)=R(m)q_m f(q,m)=R(m)qm,这是因为数学上一般使用矩阵左乘一个列向量,而代码实现的时候一般是矩阵右乘一个行向量,也正是因为这个原因,所以工程上的旋转矩阵为:
与数学上有区别。
如果在投影之前做旋转的话就无法在 attention 的时候加上相对位置信息了。
在高维空间做旋转可以拆分为在多个二维子空间上左旋转然后拼接起来。
这里只是设计了一个基准 θ k = 1000 0 − k / d \theta_k=10000^{-k/d} θk=10000−k/d,也可以设为其他。一个复数乘以 i i i 就会在复平面上逆时针旋转90°
进入·每一层Transformer的时候都会计算一次相对位置编码,不像在绝对位置编码中只在第一层Transformer中添加位置编码,这样的话随着层数的增加相对位置编码的位置信息是一直都很强的。
这里解释了上一小节中式子 (14) 是怎么来的。
3. 复数相关知识
其中 r r r 是 z z z 的模, θ \theta θ 是 z z z 辐角
内积满足线性叠加性:
4. 矩阵相关知识
正交矩阵:如果: A A T = E AA^T = E AAT=E(E为单位矩阵,AT表示“矩阵A的转置矩阵”。)或 A T A = E A^TA=E ATA=E,则n阶实矩阵A称为正交矩阵。
5. 三角函数
参考1:Transformer升级之路:2、博采众长的旋转式位置编码
参考2:一文看懂 LLaMA 中的旋转式位置编码
参考3:十分钟读懂旋转编码
参考4:Meta最新模型LLaMA细节与代码详解
参考5:B站视频