目录
- 1 代码原理简述
- 1.1 原始结构——SSM
- 1.2 结构改进——S4(Structured State Space for Sequences)
- 1.2.1 离散化
- 1.2.2HiPPO
- 1.3 最终版本——Mamba(又称S6或selective SSMs)
- 2 代码库目录结构
- 2.1 mamba_simple.py主体结构
- 2.1.1 Mamba类
- (1)__init__函数
- (2)forward函数
- (3)step函数
- 2.1.2 Block类
- 2.2 官方完整实例
Mamba原理:参考链接1,参考链接2,参考链接3
Mamba代码:参考链接,mamba-mini实现参考链接官方mamba库链接: https://github.com/state-spaces/mamba
(注:以下以1.1.4版本mamba-ssm库的代码为例)🥳🥳🥳欢迎各位大佬一起来交流讨论🥳🥳🥳
1 代码原理简述
熟悉原理的话可以直接看第二章代码
原理介绍摘录参考了链接,(这位大佬的博客应该是全网最详细的了,想摸清原理一定要读一下)
Mamba算法的本质是借鉴了现代控制理论中的状态空间模型SSM。
1.1 原始结构——SSM
状态空间模型SSM的公式如下:
通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态。(A、B、C、D这4个矩阵是可学习的参数,是可以学习到的,学习好之后,在SSM中,矩阵A、B、C、D便固定不变了,但到后续mamba中则这4个矩阵可以随着输入不同而可变)
这一过程可以由如下图来表示:
1.2 结构改进——S4(Structured State Space for Sequences)
S4模型对SSM的改进有以下三点:
- 采用零阶保持,来进行连续化:由于SSM模型是针对连续函数的,但是在文本、图像等领域,数据都是离散的,所以我们需要将离散的点连续化,才能输入进SSM模型,最后再从连续的输出中采样离散的点来得到真正的输出
- 使用循环+卷积结构表示,从而能够并行训练,加快训练速度
- 使用HIPPO矩阵,以处理远程依赖
1.2.1 离散化
由于以上公式是针对连续变量的,单实际计算机相关的应用中往往离散的,所以经过采样和零阶保持的方法,便有了以下离散化公式(这个公式很重要,在代码中会体现):
然后公式变成了这样:
在每个时间步,都会涉及到隐藏状态的更新(比如 h k h_{k} hk取决于 B ‾ x k \overline{\mathrm{B}}\mathrm{x}_{k} Bxk和 A ‾ h k − 1 \overline{\mathrm{A}}\mathrm{h}_{k-1} Ahk−1的共同作用结果,然后通过 C h k Ch_k Chk预测输出 y k y_k yk),这样反复套娃,就得到了类似RNN的循环结构:
换个图表示更明显:
用公式来表示就是这样的:
y 2 = C h 2 = C ( A ˉ h 1 + B ˉ x 2 ) = C ( A ˉ ( A ˉ h 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ( A ˉ ⋅ B ˉ x 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ⋅ A ˉ ⋅ B ˉ x 0 + A ˉ ⋅ B ˉ x 1 + B ˉ x 2 ) = C ⋅ A ˉ 2 ⋅ B ˉ x 0 + C ⋅ A ˉ ⋅ B ˉ ⋅ x 1 + C ⋅ B ˉ x 2 \begin{aligned} y_2& =Ch_{2} \\ &=C\left(\bar{A}h_{1}+\bar{B}x_{2}\right) \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}h_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}\cdot\bar{B}x_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\cdot\bar{A}\cdot\bar{B}x_{0}+\bar{A}\cdot\bar{B}x_{1}+\bar{B}x_{2}\end{pmatrix} \\ &=C\cdot\bar{A}^{2}\cdot\bar{B}x_{0}+C\cdot\bar{A}\cdot\bar{B}\cdot x_{1}+C\cdot\bar{B}x_{2} \end{aligned} y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(Aˉ⋅Bˉx0+Bˉx1)+Bˉx2)=C(Aˉ⋅Aˉ⋅Bˉx0+Aˉ⋅Bˉx1+Bˉx2)=C⋅Aˉ2⋅Bˉx0+C⋅Aˉ⋅Bˉ⋅x1+C⋅Bˉx2
由此类推:
y 3 = C A A A B ‾ x 0 + C A A B ‾ x 1 + C A B ‾ x 2 + C B ‾ x 3 y_{3}=\mathbf{C\overline{AAAB}}x_{0}+\mathbf{C\overline{AAB}}x_{1}+\mathbf{C\overline{AB}}x_{2}+\mathbf{C\overline{B}}x_{3} y3=CAAABx0+CAABx1+CABx2+CBx3
y k = C A ˉ k B ˉ x 0 + C A ˉ k − 1 B ˉ x 1 + ⋯ + C A ˉ B ˉ x k − 1 + C B ˉ x k y_{k}=C\bar{A}^{k}\bar{B}x_{0}+C\bar{A}^{k-1}\bar{B}x_{1}+\cdots+C\bar{A}\bar{B}x_{k-1}+C\bar{B}x_{k} yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CAˉBˉxk−1+CBˉxk
这样看来循环结构就更明显了。
1.2.2HiPPO
HiPPO是一种解决长距离依赖问题的方法,用于构建并调整状态矩阵A使其逼近最优解,这种方法比把A初始化为随机矩阵要好得多
1.3 最终版本——Mamba(又称S6或selective SSMs)
Mamba模型对于S4模型的改进有以下三点:
- 参数化矩阵:对输入信息进行有选择性的处理,从而得到类似Attention的效果,即不同的输入拥有不同的状态,token信息
- 硬件感知算法,并行化–选择扫描算法,加快训练推理速度
- 更简化的SSM模型架构
在最终的Mamaba中,作者让B矩阵、C矩阵、 Δ \Delta Δ 成为输入的函数,让模型能够根据输入内容自适应地调整其行为
最终的整体流程:
结构细节如下:
其中的“选择性SSM(即Selective SSM)”具有以下属性:
- Recurrent SSM通过离散化创建循环SSM
- HiPPO对矩阵A进行初始化A以捕获长程依赖性
- 选择性扫描算法(Selective scan algorithm)选择性压缩信息
- 硬件感知算法(Hardware-aware algorithm)加速计算
Mamba和Trasformer,RNN相对的优势:
2 代码库目录结构
mamba
├── benchmarks
│ └── benchmark_generation_mamba_simple.py // 示例模型的推理脚本
├── csrc
│ └── selective_scan // 选择性扫描的c++实现
├── evals
│ └── lm_harness_eval.py
├── mamba_ssm
│ ├── models
│ │ ├── config_mamba.py
│ │ └── mixer_seq_simple.py // 使用mamba构建的一个完整的语言模型示例
│ ├── modules
│ │ └── mamba_simple.py // mamba block的实现
│ ├── ops
│ │ ├── triton
│ │ │ ├── layernorm.py
│ │ │ ├── selective_state_update.py
│ │ └── selective_scan_interface.py // 选择性SSM层的实现
│ ├── utils
│ │ ├── generation.py
│ │ └── hf.py
└── test└── ops├── triton│ ├── test_selective_state_update.py└──test_selective_scan.py
2.1 mamba_simple.py主体结构
2.1.1 Mamba类
class Mamba(nn.Module):def __init__(...def forward(self, hidden_states, inference_params=None):...def step(self, hidden_states, conv_state, ssm_state):...def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):...def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):...
(1)__init__函数
Mamba类/__init__函数负责各个模块的初始化
- 输入投影初始化:输入特征通过线性投影被扩展为两倍的内部特征维度,准备进行卷积和状态空间模型的计算。
- 卷积层:使用 1D 卷积操作,捕捉局部的时间特征,使用
d_conv
作为卷积核大小。 - 状态空间模型:通过
A_log
矩阵初始化状态矩阵,并通过跳跃连接D
来捕捉长期依赖。 - 时间常数初始化:时间常数
dt
通过dt_proj
层生成,并在初始化时通过 softplus 函数的逆函数来确保其在合理范围内。 - 输出层:通过线性投影,将扩展后的特征维度恢复到输入特征维度,完成一步处理后的输出。
def __init__(self,d_model,d_state=16,d_conv=4,expand=2,dt_rank="auto",dt_min=0.001,dt_max=0.1,dt_init="random",dt_scale=1.0,dt_init_floor=1e-4,conv_bias=True,bias=False,use_fast_path=True, # Fused kernel optionslayer_idx=None,device=None,dtype=None,):factory_kwargs = {"device": device, "dtype": dtype}super().__init__()self.d_model = d_model # 输入特征维度self.d_state = d_state # 状态空间模型的维度self.d_conv = d_conv # 卷积核大小self.expand = expand # 特征扩展倍数self.d_inner = int(self.expand * self.d_model) # 扩展后的内部维度self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank # 自动确定时间常数的秩self.use_fast_path = use_fast_path # 是否使用优化的快速路径self.layer_idx = layer_idx # 当前层的索引# 线性投影层,将输入特征扩展为内部维度的两倍self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)# 1D 卷积操作,卷积核大小为 `d_conv`,输入和输出通道数均为 `d_inner`self.conv1d = nn.Conv1d(in_channels=self.d_inner,out_channels=self.d_inner,bias=conv_bias,kernel_size=d_conv,groups=self.d_inner,padding=d_conv - 1,**factory_kwargs,)# 激活函数 SiLU (Sigmoid Linear Unit),一种常用的非线性激活self.activation = "silu"self.act = nn.SiLU()# 再次线性投影,用于生成时间常数、B 和 C 矩阵self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)# 时间常数的线性投影层self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)# 初始化时间常数的投影层参数,确保初始化时方差保持稳定dt_init_std = self.dt_rank**-0.5 * dt_scaleif dt_init == "constant":nn.init.constant_(self.dt_proj.weight, dt_init_std)elif dt_init == "random":nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)else:raise NotImplementedError# 初始化 dt 投影的偏置,使得F.softplus(dt_bias)后的值在 `dt_min` 和 `dt_max` 之间dt = torch.exp(torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min)).clamp(min=dt_init_floor)# 使用 softplus 的逆函数来计算偏置inv_dt = dt + torch.log(-torch.expm1(-dt))with torch.no_grad():self.dt_proj.bias.copy_(inv_dt)# 我们的初始化会将所有的Linear.bias变成0,所以这里标记Linear.bias不需要重新初始化self.dt_proj.bias._no_reinit = True# S4D 初始化过程,状态矩阵 A 的对数A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=self.d_inner,).contiguous()A_log = torch.log(A) # 对 A 取对数,保持精度fp32self.A_log_temp = GwParameters(A_log)self.A_log = self.A_log_temp.paramself.A_log._no_weight_decay = True # 防止权重衰减# 跳跃连接参数 Dself.D_temp = GwParameters(torch.ones(self.d_inner, device=device))self.D = self.D_temp.paramself.D._no_weight_decay = True# 输出层的线性投影,将内部维度投影回原始特征维度self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
(2)forward函数
Mamba类/ forward 函数负责在前向传播过程中对输入的序列数据进行处理。它主要包括以下几步:
- 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
- 状态空间模型 (SSM):通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
- 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
- 时间常数的处理:通过
dt_proj
线性投影将特定的时间常数应用到特征上,适应不同时序的动态。 - 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
[def forward(self, hidden_states, inference_params=None):"""主要思路:1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 DReturns: 返回与输入相同形状的输出"""batch, seqlen, dim = hidden_states.shape# 初始化卷积和状态空间模型的状态为 Noneconv_state, ssm_state = None, Noneif inference_params is not None:# 如果推理参数不为空# 从缓存中获取卷积和状态空间模型的状态conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)if inference_params.seqlen_offset > 0:# 调用 step 函数逐步更新状态,并返回输出out, _, _ = self.step(hidden_states, conv_state, ssm_state)return out# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBLxz = rearrange(self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),"d (b l) -> b d l",l=seqlen,)if self.in_proj.bias is not None: # 如果存在偏置# 将偏置加到投影结果上xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")A = -torch.exp(self.A_log.float()) # 计算 A 矩阵 (d_inner, d_state)# 如果使用快速路径 (不使用推理状态且启用了快速内核)# 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作if self.use_fast_path and inference_params is None: # 不支持输出状态# 使用快速内核进行计算out = mamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,None, # input-dependent BNone, # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)else: # 如果不使用快速路径 (常规路径)# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))x, z = xz.chunk(2, dim=1) # x 和 z 分别为输入投影的两部分# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)if conv_state is not None:# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # 更新卷积状态 (B D W)# 执行一维卷积并进行激活操作 (如果存在加速的卷积函数,则使用加速路径)if causal_conv1d_fn is None: # 如果没有使用加速卷积函数x = self.act(self.conv1d(x)[..., :seqlen]) # 常规卷积操作并激活else: # 使用加速卷积函数x = causal_conv1d_fn(x=x, # 输入weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), # 调整卷积权重bias=self.conv1d.bias, # 卷积偏置activation=self.activation, # 激活函数 (silu 或 swish))# We're careful here about the layout, to avoid extra transposes.# We want dt to have d as the slowest moving dimension# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.# 对卷积结果进行投影以生成 dt, B 和 Cx_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # 线性投影 (bl d)dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, Cdt = self.dt_proj.weight @ dt.t() # 对时间常数 dt 进行线性变换dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) # 重塑为 (b, d, l)B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 BC = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 Cassert self.activation in ["silu", "swish"]# 使用 selective_scan_fn 进行状态空间模型计算y = selective_scan_fn(x,dt,A,B,C,self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=ssm_state is not None,)# 如果存在状态空间模型的状态缓存if ssm_state is not None:y, last_state = y # 更新最后的状态ssm_state.copy_(last_state) # 将更新后的状态缓存起来# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)y = rearrange(y, "b d l -> b l d")# 通过输出层的线性投影恢复原始特征维度out = self.out_proj(y)return out](<def forward(self, hidden_states, inference_params=None):"""主要思路:1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 DReturns: 返回与输入相同形状的输出"""batch, seqlen, dim = hidden_states.shape# 初始化卷积和状态空间模型的状态为 Noneconv_state, ssm_state = None, None# TODO: inference_params这个参数有什么用?什么时候会被调整?
# ============================================== inference_params is True ==============================================if inference_params is not None:# 如果推理参数不为空# 从缓存中获取卷积和状态空间模型的状态conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)if inference_params.seqlen_offset %3E 0:# 调用 step 函数逐步更新状态,并返回输出out, _, _ = self.step(hidden_states, conv_state, ssm_state)return out
# ======================================================================================================================# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBLxz = rearrange(self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),"d (b l) -> b d l",l=seqlen,)if self.in_proj.bias is not None: # 如果存在偏置# 将偏置加到投影结果上xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") # 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作A = -torch.exp(self.A_log.float()) # 计算 A 矩阵 (d_inner, d_state)# ============================================== inference_params is None ==============================================# 如果使用快速路径,整体的思路和非快速一致,可以看else中的解释if self.use_fast_path and inference_params is None: # (不使用推理状态且启用了快速内核)不支持输出状态# 使用快速内核进行计算out = mamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,None, # input-dependent BNone, # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)else: # 如果不使用快速路径 (常规路径)# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))x, z = xz.chunk(2, dim=1) # x 和 z 分别为输入投影的两部分# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)if conv_state is not None:# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # 更新卷积状态 (B D W)# ===================== Conv部分+激活操作 (如果存在加速的卷积函数,则使用加速路径) ↓=======================if causal_conv1d_fn is None: # 如果没有使用加速卷积函数x = self.act(self.conv1d(x)[..., :seqlen]) # 常规卷积操作并激活else: # 使用加速卷积函数x = causal_conv1d_fn(x=x, # 输入weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), # 调整卷积权重bias=self.conv1d.bias, # 卷积偏置activation=self.activation, # 激活函数 (silu 或 swish))# ====================================== Conv部分+激活操作 ↑========================================# ======================================== SSM 部分 ↓ ========================================# We're careful here about the layout, to avoid extra transposes.# We want dt to have d as the slowest moving dimension# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.# ==对卷积结果进行投影以生成 dt, B 和 C==x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # 线性投影 (bl d)dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, Cdt = self.dt_proj.weight @ dt.t() # 对时间常数 dt 进行线性变换dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) # 重塑为 (b, d, l)B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 BC = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() # 重塑 C#====================================assert self.activation in ["silu", "swish"]# 使用 selective_scan_fn 进行状态空间模型计算y = selective_scan_fn(x,dt,A,B,C,self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=ssm_state is not None,)# 如果存在状态空间模型的状态缓存if ssm_state is not None:y, last_state = y # 更新最后的状态ssm_state.copy_(last_state) # 将更新后的状态缓存起来# ==================================== SSM 部分 ↑ =============================================# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)y = rearrange(y, "b d l -> b l d")# 通过输出层的线性投影恢复原始特征维度out = self.out_proj(y)return out>)
(3)step函数
Mamba类/ step 函数, 是一次完整的mamba块处理过程(inference_params不为空时执行)。负责在解码过程中处理单个时间步长的数据,并更新卷积状态和状态空间模型(SSM)的状态。主要逻辑分为两部分:
- 卷积步骤 (Conv Step):对输入进行卷积操作,用来捕捉局部的时序特征。
- 状态空间模型步骤 (SSM Step):利用状态空间模型进行时间序列的动态建模,捕捉序列中的长期依赖关系。
[def step(self, hidden_states, conv_state, ssm_state):"""执行过程中单步处理hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度conv_state: 卷积的状态缓存 (B, D, W)ssm_state: 状态空间模型的状态缓存 (B, D, S)Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state"""# 确保输入的第二维是1,即仅处理一个 tokendtype = hidden_states.dtypeassert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -> (B, 2D)xz = self.in_proj(hidden_states.squeeze(1)) # (B, 2D)# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)x, z = xz.chunk(2, dim=-1) # (B, D)# 卷积操作if causal_conv1d_update is None: # 如果没有使用加速卷积更新函数# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))conv_state[:, :, -1] = x # 将当前输入 x 放到状态缓存的最后一个位置# 执行一维卷积,卷积结果为 (B, D)x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)if self.conv1d.bias is not None: # 如果存在卷积偏置x = x + self.conv1d.bias # 加上偏置# 激活函数处理并将数据类型转回为原始输入的 dtypex = self.act(x).to(dtype=dtype)else: # 如果使用了加速的卷积更新函数x = causal_conv1d_update(x, # 输入conv_state, # 卷积状态rearrange(self.conv1d.weight, "d 1 w -> d w"), # 卷积权重self.conv1d.bias, # 卷积偏置self.activation, # 激活函数 (如 silu 或 swish))# 对卷积后的 x 进行线性投影,得到 dt, B 和 Cx_db = self.x_proj(x) # (B, dt_rank + 2 * d_state)dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C# 时间常数 dt 通过线性投影更新,不加偏置dt = F.linear(dt, self.dt_proj.weight) # (B, d_inner)# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数A = -torch.exp(self.A_log.float()) # (d_inner, d_state)# 状态空间模型步骤 (SSM Step)if selective_state_update is None: # 如果没有使用选择性状态更新函数# 离散化 A 和 B 矩阵dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) # 对 dt 进行 softplus 激活并加上偏置dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) # 计算离散化后的 A 矩阵dB = torch.einsum("bd,bn->bdn", dt, B) # 计算离散化后的 B 矩阵# 更新状态空间模型的状态ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)# 计算输出 yy = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) # 状态乘以 C 得到输出y = y + self.D.to(dtype) * x # 加上 D * x 的跳跃连接y = y * self.act(z) # 与 z 经过激活函数的结果相乘 (B, D)else: # 如果使用了选择性状态更新函数y = selective_state_update( # 调用加速的选择性状态更新函数ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True)# 将输出通过线性投影层out = self.out_proj(y) # (B, D)return out.unsqueeze(1), conv_state, ssm_state # 返回输出 (B, 1, D) 以及更新后的状态](<def step(self, hidden_states, conv_state, ssm_state):"""执行过程中单步处理hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度conv_state: 卷积的状态缓存 (B, D, W)ssm_state: 状态空间模型的状态缓存 (B, D, S)Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state"""# ===================== 输入预处理部分 ↓ =====================# 确保输入的第二维是1,即仅处理一个 tokendtype = hidden_states.dtypeassert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -%3E (B, 2D)xz = self.in_proj(hidden_states.squeeze(1)) # (B, 2D)# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)x, z = xz.chunk(2, dim=-1) # (B, D)# ===================== 输入预处理部分 ↑ =====================# ================== 卷积部分 + 激活操作 ↓ =====================if causal_conv1d_update is None: # 如果没有使用加速卷积更新函数# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))conv_state[:, :, -1] = x # 将当前输入 x 放到状态缓存的最后一个位置# 执行一维卷积,卷积结果为 (B, D)x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)if self.conv1d.bias is not None: # 如果存在卷积偏置x = x + self.conv1d.bias # 加上偏置# 激活函数处理并将数据类型转回为原始输入的 dtypex = self.act(x).to(dtype=dtype)else: # 如果使用了加速的卷积更新函数x = causal_conv1d_update(x, # 输入conv_state, # 卷积状态rearrange(self.conv1d.weight, "d 1 w -> d w"), # 卷积权重self.conv1d.bias, # 卷积偏置self.activation, # 激活函数 (如 silu 或 swish))# ================== 卷积部分 + 激活操作 ↑ =====================# ==================== SSM 部分 ↓ ==========================# 对卷积后的 x 进行线性投影,得到 dt, B 和 Cx_db = self.x_proj(x) # (B, dt_rank + 2 * d_state)dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 分割为 dt, B, C# 时间常数 dt 通过线性投影更新,不加偏置dt = F.linear(dt, self.dt_proj.weight) # (B, d_inner)# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数A = -torch.exp(self.A_log.float()) # (d_inner, d_state)# ====状态空间模型具体步骤 (SSM Step)====if selective_state_update is None: # 如果没有使用选择性状态更新函数# 离散化 A 和 B 矩阵dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) # 对 dt 进行 softplus 激活并加上偏置dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) # 计算离散化后的 A 矩阵dB = torch.einsum("bd,bn->bdn", dt, B) # 计算离散化后的 B 矩阵# 更新状态空间模型的状态ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)# 计算输出 yy = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) # 状态乘以 C 得到输出y = y + self.D.to(dtype) * x # 加上 D * x 的跳跃连接# ======= mamba块右边的激活叠加 ========y = y * self.act(z) # 与 z 经过激活函数的结果相乘 (B, D)else: # 如果使用了选择性状态更新函数y = selective_state_update( # 调用加速的选择性状态更新函数ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True)# ===================== SSM 部分 ↑ =========================# 将输出通过线性投影层out = self.out_proj(y) # (B, D)return out.unsqueeze(1), conv_state, ssm_state # 返回输出 (B, 1, D) 以及更新后的状态>)
此外,辅助函数 allocate_inference_cache
和 _get_states_from_cache
用于在推理过程中管理卷积状态和状态空间模型的状态缓存。
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):"""为推理过程分配卷积和状态空间模型的状态缓存。batch_size: 批次大小max_seqlen: 最大序列长度dtype: 数据类型 (可选)"""device = self.out_proj.weight.device # 获取设备 (CPU 或 GPU)conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype # 如果未指定 dtype,则使用卷积权重的 dtype# 初始化卷积状态,形状为 (B, D, W),其中 W 是卷积窗口大小conv_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype)ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype # 如果未指定 dtype,则使用 dt_proj 权重的 dtype# 初始化状态空间模型的状态,形状为 (B, D, S),其中 S 是状态维度ssm_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype)return conv_state, ssm_state # 返回分配的状态def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):"""从缓存中获取卷积和状态空间模型的状态,若缓存不存在则初始化状态。inference_params: 推理参数batch_size: 批次大小initialize_states: 是否初始化状态"""assert self.layer_idx is not None # 确保层索引存在if self.layer_idx not in inference_params.key_value_memory_dict: # 如果缓存中没有当前层的状态batch_shape = (batch_size,)# 初始化卷积状态conv_state = torch.zeros(batch_size,self.d_model * self.expand,self.d_conv,device=self.conv1d.weight.device,dtype=self.conv1d.weight.dtype,)# 初始化状态空间模型的状态ssm_state = torch.zeros(batch_size,self.d_model * self.expand,self.d_state,device=self.dt_proj.weight.device,dtype=self.dt_proj.weight.dtype,# dtype=torch.float32,)# 将初始化的状态存入缓存字典inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)else: # 如果缓存中已有状态conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]# TODO: What if batch size changes between generation, and we reuse the same states?if initialize_states: # 如果需要初始化状态conv_state.zero_() # 将卷积状态清零ssm_state.zero_() # 将状态空间模型的状态清零return conv_state, ssm_state # 返回卷积和状态空间模型的状态
2.1.2 Block类
用于构建包含 Mamba
模块的更大网络。封装了 LayerNorm/RMSNorm 和残差连接的简单模块,以及Mamba的类。这个 Block 与传统的 Transformer 预归一化块稍有不同。
标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
而这里的顺序是: Add -> LN -> Mixer ,并且返回的是 hidden_states 和 residual。
这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
需要提供残差,除非是第一个块(第一个块不需要残差输入)。
......class Block(nn.Module): def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False):"""这是一个封装了 LayerNorm/RMSNorm 和残差连接的简单模块,包含了一个 mixer(Mamba) 类。这个 Block 与传统的 Transformer 预归一化块稍有不同。标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。而这里的顺序是: Add -> LN -> Mixer,并且返回的是 hidden_states 和 residual。这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。需要提供残差,除非是第一个块(第一个块不需要残差输入)。"""super().__init__() # 调用父类的初始化方法,确保 nn.Module 的属性被正确初始化self.residual_in_fp32 = residual_in_fp32 # 是否在 fp32 精度下处理残差self.fused_add_norm = fused_add_norm # 是否启用残差加法和归一化融合self.mixer = mixer_cls(dim) # 初始化 mixer 模块,该模块用于处理输入数据的混合self.norm = norm_cls(dim) # 初始化归一化层,默认为 nn.LayerNorm# 如果启用了融合归一化,需要检查 RMSNorm 是否正确导入,并验证 self.norm 是 LayerNorm 或 RMSNormif self.fused_add_norm:assert RMSNorm is not None, "RMSNorm 导入失败" # 确保 RMSNorm 导入成功assert isinstance( # 确保 norm 是 LayerNorm 或 RMSNorm 类型self.norm, (nn.LayerNorm, RMSNorm)), "fused_add_norm 模式下仅支持 LayerNorm 和 RMSNorm"def forward( # 定义前向传播方法self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):"""将输入传递给编码层的前向传播方法。参数:hidden_states: 输入序列 (必须提供)。residual: 输入的残差,如果 residual 为 None,hidden_states 直接作为残差使用。"""# 如果不启用 fused_add_norm,则进行标准的 LayerNorm 操作if not self.fused_add_norm:# 计算残差: 如果传入了 residual,将其加到 hidden_states 上,否则将 hidden_states 作为 residualresidual = (hidden_states + residual) if residual is not None else hidden_states# 将 residual 经过 norm 归一化hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))# 如果 residual_in_fp32 为真,将 residual 转换为 float32 处理if self.residual_in_fp32:residual = residual.to(torch.float32)else: # 启用了 fused_add_norm 时,使用融合的归一化和加法操作fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn# 使用 fused_add_norm_fn 处理 hidden_states 和 residualhidden_states, residual = fused_add_norm_fn(hidden_states,self.norm.weight,self.norm.bias,residual=residual,prenorm=True, # 使用预归一化residual_in_fp32=self.residual_in_fp32, # 在 fp32 中计算残差eps=self.norm.eps, # 归一化时使用的 epsilon 值,防止除零错误)# 将归一化后的 hidden_states 传递给 mixer 模块进行进一步处理hidden_states = self.mixer(hidden_states, inference_params=inference_params)# 返回处理后的 hidden_states 和 residualreturn hidden_states, residualdef allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):# 为推理分配缓存空间,用于存储中间状态,调用 mixer 的同名方法return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
2.2 官方完整实例
mixer_seq_simple.py 是官方的一个完整实例。
一个比较关键的函数是create_block函数,相当于使用mamba的一个接口。
create_block函数中包含了具体应该如何使用mamba_simple.py中的Block类
...
def create_block(d_model,ssm_cfg=None,norm_epsilon=1e-5,rms_norm=False,residual_in_fp32=False,fused_add_norm=False,layer_idx=None,device=None,dtype=None,
):if ssm_cfg is None:ssm_cfg = {}factory_kwargs = {"device": device, "dtype": dtype}# 创建Mamba混合器类的偏函数,允许传递额外的配置和工厂关键字参数mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)# 根据是否使用RMSNorm,创建归一化层类的偏函数,并设置epsilon和其他工厂关键字参数norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)# 创建Block实例,使用前面定义的混合器和归一化类block = Block(d_model,mixer_cls,norm_cls=norm_cls,fused_add_norm=fused_add_norm,residual_in_fp32=residual_in_fp32,)# 设置层索引,便于在模型中跟踪和配置此块block.layer_idx = layer_idxreturn block
...
(未完待续。。。)