【官方Mamba库】原理简述和代码解析

目录

  • 1 代码原理简述
    • 1.1 原始结构——SSM
    • 1.2 结构改进——S4(Structured State Space for Sequences)
      • 1.2.1 离散化
      • 1.2.2HiPPO
    • 1.3 最终版本——Mamba(又称S6或selective SSMs)
  • 2 代码库目录结构
    • 2.1 mamba_simple.py主体结构
      • 2.1.1 Mamba类
        • (1)__init__函数
        • (2)forward函数
        • (3)step函数
      • 2.1.2 Block类
    • 2.2 官方完整实例

Mamba原理:参考链接1,参考链接2,参考链接3
Mamba代码:参考链接,mamba-mini实现参考链接

官方mamba库链接: https://github.com/state-spaces/mamba
(注:以下以1.1.4版本mamba-ssm库的代码为例)

🥳🥳🥳欢迎各位大佬一起来交流讨论🥳🥳🥳

1 代码原理简述

熟悉原理的话可以直接看第二章代码

原理介绍摘录参考了链接,(这位大佬的博客应该是全网最详细的了,想摸清原理一定要读一下)

Mamba算法的本质是借鉴了现代控制理论中的状态空间模型SSM

1.1 原始结构——SSM

状态空间模型SSM的公式如下:
请添加图片描述

通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态。(A、B、C、D这4个矩阵是可学习的参数,是可以学习到的,学习好之后,在SSM中,矩阵A、B、C、D便固定不变了,但到后续mamba中则这4个矩阵可以随着输入不同而可变)
这一过程可以由如下图来表示:
请添加图片描述

1.2 结构改进——S4(Structured State Space for Sequences)

S4模型对SSM的改进有以下三点:

  1. 采用零阶保持,来进行连续化:由于SSM模型是针对连续函数的,但是在文本、图像等领域,数据都是离散的,所以我们需要将离散的点连续化,才能输入进SSM模型,最后再从连续的输出中采样离散的点来得到真正的输出
  2. 使用循环+卷积结构表示,从而能够并行训练,加快训练速度
  3. 使用HIPPO矩阵,以处理远程依赖

1.2.1 离散化

由于以上公式是针对连续变量的,单实际计算机相关的应用中往往离散的,所以经过采样零阶保持的方法,便有了以下离散化公式(这个公式很重要,在代码中会体现):
请添加图片描述

然后公式变成了这样:
请添加图片描述
在每个时间步,都会涉及到隐藏状态的更新(比如 h k h_{k} hk取决于 B ‾ x k \overline{\mathrm{B}}\mathrm{x}_{k} Bxk A ‾ h k − 1 \overline{\mathrm{A}}\mathrm{h}_{k-1} Ahk1的共同作用结果,然后通过 C h k Ch_k Chk预测输出 y k y_k yk),这样反复套娃,就得到了类似RNN的循环结构:
请添加图片描述
换个图表示更明显:
在这里插入图片描述

用公式来表示就是这样的:
y 2 = C h 2 = C ( A ˉ h 1 + B ˉ x 2 ) = C ( A ˉ ( A ˉ h 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ( A ˉ ⋅ B ˉ x 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ⋅ A ˉ ⋅ B ˉ x 0 + A ˉ ⋅ B ˉ x 1 + B ˉ x 2 ) = C ⋅ A ˉ 2 ⋅ B ˉ x 0 + C ⋅ A ˉ ⋅ B ˉ ⋅ x 1 + C ⋅ B ˉ x 2 \begin{aligned} y_2& =Ch_{2} \\ &=C\left(\bar{A}h_{1}+\bar{B}x_{2}\right) \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}h_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\begin{pmatrix}\bar{A}\cdot\bar{B}x_{0}+\bar{B}x_{1}\end{pmatrix}+\bar{B}x_{2}\end{pmatrix} \\ &=C\begin{pmatrix}\bar{A}\cdot\bar{A}\cdot\bar{B}x_{0}+\bar{A}\cdot\bar{B}x_{1}+\bar{B}x_{2}\end{pmatrix} \\ &=C\cdot\bar{A}^{2}\cdot\bar{B}x_{0}+C\cdot\bar{A}\cdot\bar{B}\cdot x_{1}+C\cdot\bar{B}x_{2} \end{aligned} y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(AˉBˉx0+Bˉx1)+Bˉx2)=C(AˉAˉBˉx0+AˉBˉx1+Bˉx2)=CAˉ2Bˉx0+CAˉBˉx1+CBˉx2
由此类推:
y 3 = C A A A B ‾ x 0 + C A A B ‾ x 1 + C A B ‾ x 2 + C B ‾ x 3 y_{3}=\mathbf{C\overline{AAAB}}x_{0}+\mathbf{C\overline{AAB}}x_{1}+\mathbf{C\overline{AB}}x_{2}+\mathbf{C\overline{B}}x_{3} y3=CAAABx0+CAABx1+CABx2+CBx3
y k = C A ˉ k B ˉ x 0 + C A ˉ k − 1 B ˉ x 1 + ⋯ + C A ˉ B ˉ x k − 1 + C B ˉ x k y_{k}=C\bar{A}^{k}\bar{B}x_{0}+C\bar{A}^{k-1}\bar{B}x_{1}+\cdots+C\bar{A}\bar{B}x_{k-1}+C\bar{B}x_{k} yk=CAˉkBˉx0+CAˉk1Bˉx1++CAˉBˉxk1+CBˉxk
这样看来循环结构就更明显了。

1.2.2HiPPO

HiPPO是一种解决长距离依赖问题的方法,用于构建并调整状态矩阵A使其逼近最优解,这种方法比把A初始化为随机矩阵要好得多
请添加图片描述

1.3 最终版本——Mamba(又称S6或selective SSMs)

Mamba模型对于S4模型的改进有以下三点:

  1. 参数化矩阵:对输入信息进行有选择性的处理,从而得到类似Attention的效果,即不同的输入拥有不同的状态,token信息
  2. 硬件感知算法,并行化–选择扫描算法,加快训练推理速度
  3. 更简化的SSM模型架构

在最终的Mamaba中,作者让B矩阵、C矩阵、 Δ \Delta Δ 成为输入的函数,让模型能够根据输入内容自适应地调整其行为
请添加图片描述

最终的整体流程:
请添加图片描述
结构细节如下:
请添加图片描述

其中的“选择性SSM(即Selective SSM)”具有以下属性:

  1. Recurrent SSM通过离散化创建循环SSM
  2. HiPPO对矩阵A进行初始化A以捕获长程依赖性
  3. 选择性扫描算法(Selective scan algorithm)选择性压缩信息
  4. 硬件感知算法(Hardware-aware algorithm)加速计算

Mamba和Trasformer,RNN相对的优势:
请添加图片描述

2 代码库目录结构

mamba
├── benchmarks
│ 	└── benchmark_generation_mamba_simple.py  // 示例模型的推理脚本
├── csrc
│ 	└── selective_scan  // 选择性扫描的c++实现
├── evals
│ 	└── lm_harness_eval.py
├── mamba_ssm
│ 	├── models
│   │   ├── config_mamba.py
│   │   └── mixer_seq_simple.py  // 使用mamba构建的一个完整的语言模型示例
│ 	├── modules
│   │   └── mamba_simple.py   // mamba block的实现
│ 	├── ops
│   │   ├── triton
│   │   │   ├── layernorm.py
│   │   │   ├── selective_state_update.py
│   │   └── selective_scan_interface.py   // 选择性SSM层的实现
│ 	├── utils
│   │   ├── generation.py
│   │   └── hf.py
└── test└── ops├── triton│		├── test_selective_state_update.py└──test_selective_scan.py

2.1 mamba_simple.py主体结构

2.1.1 Mamba类

class Mamba(nn.Module):def __init__(...def forward(self, hidden_states, inference_params=None):...def step(self, hidden_states, conv_state, ssm_state):...def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):...def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):...
(1)__init__函数

Mamba类/__init__函数负责各个模块的初始化

  • 输入投影初始化:输入特征通过线性投影被扩展为两倍的内部特征维度,准备进行卷积和状态空间模型的计算。
  • 卷积层:使用 1D 卷积操作,捕捉局部的时间特征,使用 d_conv 作为卷积核大小。
  • 状态空间模型:通过 A_log 矩阵初始化状态矩阵,并通过跳跃连接 D 来捕捉长期依赖。
  • 时间常数初始化:时间常数 dt 通过 dt_proj 层生成,并在初始化时通过 softplus 函数的逆函数来确保其在合理范围内。
  • 输出层:通过线性投影,将扩展后的特征维度恢复到输入特征维度,完成一步处理后的输出。
def __init__(self,d_model,d_state=16,d_conv=4,expand=2,dt_rank="auto",dt_min=0.001,dt_max=0.1,dt_init="random",dt_scale=1.0,dt_init_floor=1e-4,conv_bias=True,bias=False,use_fast_path=True,  # Fused kernel optionslayer_idx=None,device=None,dtype=None,):factory_kwargs = {"device": device, "dtype": dtype}super().__init__()self.d_model = d_model  # 输入特征维度self.d_state = d_state  # 状态空间模型的维度self.d_conv = d_conv    # 卷积核大小self.expand = expand    # 特征扩展倍数self.d_inner = int(self.expand * self.d_model)  # 扩展后的内部维度self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank  # 自动确定时间常数的秩self.use_fast_path = use_fast_path  # 是否使用优化的快速路径self.layer_idx = layer_idx  # 当前层的索引# 线性投影层,将输入特征扩展为内部维度的两倍self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)# 1D 卷积操作,卷积核大小为 `d_conv`,输入和输出通道数均为 `d_inner`self.conv1d = nn.Conv1d(in_channels=self.d_inner,out_channels=self.d_inner,bias=conv_bias,kernel_size=d_conv,groups=self.d_inner,padding=d_conv - 1,**factory_kwargs,)# 激活函数 SiLU (Sigmoid Linear Unit),一种常用的非线性激活self.activation = "silu"self.act = nn.SiLU()# 再次线性投影,用于生成时间常数、B 和 C 矩阵self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)# 时间常数的线性投影层self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)# 初始化时间常数的投影层参数,确保初始化时方差保持稳定dt_init_std = self.dt_rank**-0.5 * dt_scaleif dt_init == "constant":nn.init.constant_(self.dt_proj.weight, dt_init_std)elif dt_init == "random":nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)else:raise NotImplementedError# 初始化 dt 投影的偏置,使得F.softplus(dt_bias)后的值在 `dt_min` 和 `dt_max` 之间dt = torch.exp(torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min)).clamp(min=dt_init_floor)# 使用 softplus 的逆函数来计算偏置inv_dt = dt + torch.log(-torch.expm1(-dt))with torch.no_grad():self.dt_proj.bias.copy_(inv_dt)# 我们的初始化会将所有的Linear.bias变成0,所以这里标记Linear.bias不需要重新初始化self.dt_proj.bias._no_reinit = True# S4D 初始化过程,状态矩阵 A 的对数A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=self.d_inner,).contiguous()A_log = torch.log(A)  # 对 A 取对数,保持精度fp32self.A_log_temp = GwParameters(A_log)self.A_log = self.A_log_temp.paramself.A_log._no_weight_decay = True # 防止权重衰减# 跳跃连接参数 Dself.D_temp = GwParameters(torch.ones(self.d_inner, device=device))self.D = self.D_temp.paramself.D._no_weight_decay = True# 输出层的线性投影,将内部维度投影回原始特征维度self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
(2)forward函数

Mamba类/ forward 函数负责在前向传播过程中对输入的序列数据进行处理。它主要包括以下几步:

  1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。
  2. 状态空间模型 (SSM):通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。
  3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。
  4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。
  5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。
[def forward(self, hidden_states, inference_params=None):"""主要思路:1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 DReturns: 返回与输入相同形状的输出"""batch, seqlen, dim = hidden_states.shape# 初始化卷积和状态空间模型的状态为 Noneconv_state, ssm_state = None, Noneif inference_params is not None:# 如果推理参数不为空# 从缓存中获取卷积和状态空间模型的状态conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)if inference_params.seqlen_offset > 0:# 调用 step 函数逐步更新状态,并返回输出out, _, _ = self.step(hidden_states, conv_state, ssm_state)return out# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBLxz = rearrange(self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),"d (b l) -> b d l",l=seqlen,)if self.in_proj.bias is not None:  # 如果存在偏置# 将偏置加到投影结果上xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")A = -torch.exp(self.A_log.float())  # 计算 A 矩阵 (d_inner, d_state)# 如果使用快速路径 (不使用推理状态且启用了快速内核)# 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作if self.use_fast_path and inference_params is None:  # 不支持输出状态# 使用快速内核进行计算out = mamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,None,  # input-dependent BNone,  # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)else: # 如果不使用快速路径 (常规路径)# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))x, z = xz.chunk(2, dim=1)  # x 和 z 分别为输入投影的两部分# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)if conv_state is not None:# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # 更新卷积状态 (B D W)# 执行一维卷积并进行激活操作 (如果存在加速的卷积函数,则使用加速路径)if causal_conv1d_fn is None:  # 如果没有使用加速卷积函数x = self.act(self.conv1d(x)[..., :seqlen])  # 常规卷积操作并激活else:  # 使用加速卷积函数x = causal_conv1d_fn(x=x,                                    # 输入weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 调整卷积权重bias=self.conv1d.bias,                  # 卷积偏置activation=self.activation,             # 激活函数 (silu 或 swish))# We're careful here about the layout, to avoid extra transposes.# We want dt to have d as the slowest moving dimension# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.# 对卷积结果进行投影以生成 dt, B 和 Cx_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # 线性投影 (bl d)dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, Cdt = self.dt_proj.weight @ dt.t()  # 对时间常数 dt 进行线性变换dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)  # 重塑为 (b, d, l)B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 BC = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 Cassert self.activation in ["silu", "swish"]# 使用 selective_scan_fn 进行状态空间模型计算y = selective_scan_fn(x,dt,A,B,C,self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=ssm_state is not None,)# 如果存在状态空间模型的状态缓存if ssm_state is not None:y, last_state = y  # 更新最后的状态ssm_state.copy_(last_state)  # 将更新后的状态缓存起来# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)y = rearrange(y, "b d l -> b l d")# 通过输出层的线性投影恢复原始特征维度out = self.out_proj(y)return out](<def forward(self, hidden_states, inference_params=None):"""主要思路:1. 输入预处理:对输入的序列进行线性投影和卷积操作,以提取有用的特征。2. 状态空间模型 (SSM) :通过状态空间模型对时间序列的短期和长期依赖进行建模,捕捉序列中的动态变化。3. 卷积层和激活函数:使用一维卷积层提取局部时序信息,激活函数用于引入非线性特征。4. 时间常数的处理:通过 dt_proj 线性投影将特定的时间常数应用到特征上,适应不同时序的动态。5. 输出层:处理后的特征经过线性投影恢复为输入的形状,最终得到输出。hidden_states: (B, L, D) 表示批次大小 B、序列长度 L、特征维度 DReturns: 返回与输入相同形状的输出"""batch, seqlen, dim = hidden_states.shape# 初始化卷积和状态空间模型的状态为 Noneconv_state, ssm_state = None, None# TODO: inference_params这个参数有什么用?什么时候会被调整?
# ============================================== inference_params is True ==============================================if inference_params is not None:# 如果推理参数不为空# 从缓存中获取卷积和状态空间模型的状态conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)if inference_params.seqlen_offset %3E 0:# 调用 step 函数逐步更新状态,并返回输出out, _, _ = self.step(hidden_states, conv_state, ssm_state)return out
# ======================================================================================================================# 线性投影和转置:将输入序列 (B, L, D) 投影到内部维度,并进行转置 BLH -> HBLxz = rearrange(self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),"d (b l) -> b d l",l=seqlen,)if self.in_proj.bias is not None:  # 如果存在偏置# 将偏置加到投影结果上xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") # 在反向传播中,我们将 dx 和 dz 写在一起,以避免 torch.cat 操作A = -torch.exp(self.A_log.float())  # 计算 A 矩阵 (d_inner, d_state)# ============================================== inference_params is None ==============================================# 如果使用快速路径,整体的思路和非快速一致,可以看else中的解释if self.use_fast_path and inference_params is None:  # (不使用推理状态且启用了快速内核)不支持输出状态# 使用快速内核进行计算out = mamba_inner_fn(xz,self.conv1d.weight,self.conv1d.bias,self.x_proj.weight,self.dt_proj.weight,self.out_proj.weight,self.out_proj.bias,A,None,  # input-dependent BNone,  # input-dependent Cself.D.float(),delta_bias=self.dt_proj.bias.float(),delta_softplus=True,)else: # 如果不使用快速路径 (常规路径)# 将 xz 分割为 x 和 z 两部分 (xz 的形状为 (B, 2 * d_inner, L))x, z = xz.chunk(2, dim=1)  # x 和 z 分别为输入投影的两部分# 处理短卷积 (如果推理时有状态缓存,则更新卷积状态)if conv_state is not None:# 使用 F.pad 填充卷积状态,使得卷积操作在序列长度不足时不会出错conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # 更新卷积状态 (B D W)# ===================== Conv部分+激活操作 (如果存在加速的卷积函数,则使用加速路径) ↓=======================if causal_conv1d_fn is None:  # 如果没有使用加速卷积函数x = self.act(self.conv1d(x)[..., :seqlen])  # 常规卷积操作并激活else:  # 使用加速卷积函数x = causal_conv1d_fn(x=x,                                    # 输入weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 调整卷积权重bias=self.conv1d.bias,                  # 卷积偏置activation=self.activation,             # 激活函数 (silu 或 swish))# ====================================== Conv部分+激活操作 ↑========================================# ======================================== SSM 部分 ↓ ========================================# We're careful here about the layout, to avoid extra transposes.# We want dt to have d as the slowest moving dimension# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.# ==对卷积结果进行投影以生成 dt, B 和 C==x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # 线性投影 (bl d)dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, Cdt = self.dt_proj.weight @ dt.t()  # 对时间常数 dt 进行线性变换dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)  # 重塑为 (b, d, l)B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 BC = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()  # 重塑 C#====================================assert self.activation in ["silu", "swish"]# 使用 selective_scan_fn 进行状态空间模型计算y = selective_scan_fn(x,dt,A,B,C,self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=ssm_state is not None,)# 如果存在状态空间模型的状态缓存if ssm_state is not None:y, last_state = y  # 更新最后的状态ssm_state.copy_(last_state)  # 将更新后的状态缓存起来# ==================================== SSM 部分 ↑ =============================================# 将输出 y 重新调整维度 (B, D, L) -> (B, L, D)y = rearrange(y, "b d l -> b l d")# 通过输出层的线性投影恢复原始特征维度out = self.out_proj(y)return out>)
(3)step函数

Mamba类/ step 函数, 是一次完整的mamba块处理过程(inference_params不为空时执行)。负责在解码过程中处理单个时间步长的数据,并更新卷积状态和状态空间模型(SSM)的状态。主要逻辑分为两部分:

  1. 卷积步骤 (Conv Step):对输入进行卷积操作,用来捕捉局部的时序特征。
  2. 状态空间模型步骤 (SSM Step):利用状态空间模型进行时间序列的动态建模,捕捉序列中的长期依赖关系。
[def step(self, hidden_states, conv_state, ssm_state):"""执行过程中单步处理hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度conv_state: 卷积的状态缓存 (B, D, W)ssm_state: 状态空间模型的状态缓存 (B, D, S)Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state"""# 确保输入的第二维是1,即仅处理一个 tokendtype = hidden_states.dtypeassert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -> (B, 2D)xz = self.in_proj(hidden_states.squeeze(1))  # (B, 2D)# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)x, z = xz.chunk(2, dim=-1)  # (B, D)# 卷积操作if causal_conv1d_update is None:  # 如果没有使用加速卷积更新函数# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))conv_state[:, :, -1] = x  # 将当前输入 x 放到状态缓存的最后一个位置# 执行一维卷积,卷积结果为 (B, D)x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)if self.conv1d.bias is not None:  # 如果存在卷积偏置x = x + self.conv1d.bias  # 加上偏置# 激活函数处理并将数据类型转回为原始输入的 dtypex = self.act(x).to(dtype=dtype)else:  # 如果使用了加速的卷积更新函数x = causal_conv1d_update(x,                     # 输入conv_state,            # 卷积状态rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 卷积权重self.conv1d.bias,      # 卷积偏置self.activation,       # 激活函数 (如 silu 或 swish))# 对卷积后的 x 进行线性投影,得到 dt, B 和 Cx_db = self.x_proj(x)  # (B, dt_rank + 2 * d_state)dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C# 时间常数 dt 通过线性投影更新,不加偏置dt = F.linear(dt, self.dt_proj.weight)  # (B, d_inner)# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数A = -torch.exp(self.A_log.float())  # (d_inner, d_state)# 状态空间模型步骤 (SSM Step)if selective_state_update is None:  # 如果没有使用选择性状态更新函数# 离散化 A 和 B 矩阵dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))  # 对 dt 进行 softplus 激活并加上偏置dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))  # 计算离散化后的 A 矩阵dB = torch.einsum("bd,bn->bdn", dt, B)  # 计算离散化后的 B 矩阵# 更新状态空间模型的状态ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)# 计算输出 yy = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)  # 状态乘以 C 得到输出y = y + self.D.to(dtype) * x  # 加上 D * x 的跳跃连接y = y * self.act(z)  # 与 z 经过激活函数的结果相乘 (B, D)else:  # 如果使用了选择性状态更新函数y = selective_state_update(  # 调用加速的选择性状态更新函数ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True)# 将输出通过线性投影层out = self.out_proj(y)  # (B, D)return out.unsqueeze(1), conv_state, ssm_state  # 返回输出 (B, 1, D) 以及更新后的状态](<def step(self, hidden_states, conv_state, ssm_state):"""执行过程中单步处理hidden_states: 当前时间步的输入张量 (B, 1, D), B 是批次大小,1 表示单个 token 的输入,D 是特征维度conv_state: 卷积的状态缓存 (B, D, W)ssm_state: 状态空间模型的状态缓存 (B, D, S)Returns: 返回当前时间步的输出 (B, 1, D),以及更新后的 conv_state 和 ssm_state"""# ===================== 输入预处理部分 ↓ =====================# 确保输入的第二维是1,即仅处理一个 tokendtype = hidden_states.dtypeassert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"# 将输入进行线性投影,并将其 squeeze 以去除长度为 1 的维度 (B, 1, D) -%3E (B, 2D)xz = self.in_proj(hidden_states.squeeze(1))  # (B, 2D)# 将投影后的 xz 分为两部分:x 和 z,每部分维度为 (B, D)x, z = xz.chunk(2, dim=-1)  # (B, D)# ===================== 输入预处理部分 ↑ =====================# ================== 卷积部分 + 激活操作 ↓ =====================if causal_conv1d_update is None:  # 如果没有使用加速卷积更新函数# 使用 torch.roll 更新卷积状态,向左滚动,丢弃最早的值 (B, D, W)conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))conv_state[:, :, -1] = x  # 将当前输入 x 放到状态缓存的最后一个位置# 执行一维卷积,卷积结果为 (B, D)x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)if self.conv1d.bias is not None:  # 如果存在卷积偏置x = x + self.conv1d.bias  # 加上偏置# 激活函数处理并将数据类型转回为原始输入的 dtypex = self.act(x).to(dtype=dtype)else:  # 如果使用了加速的卷积更新函数x = causal_conv1d_update(x,                     # 输入conv_state,            # 卷积状态rearrange(self.conv1d.weight, "d 1 w -> d w"),  # 卷积权重self.conv1d.bias,      # 卷积偏置self.activation,       # 激活函数 (如 silu 或 swish))# ================== 卷积部分 + 激活操作 ↑ =====================# ====================   SSM 部分 ↓ ==========================# 对卷积后的 x 进行线性投影,得到 dt, B 和 Cx_db = self.x_proj(x)  # (B, dt_rank + 2 * d_state)dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)  # 分割为 dt, B, C# 时间常数 dt 通过线性投影更新,不加偏置dt = F.linear(dt, self.dt_proj.weight)  # (B, d_inner)# 计算状态空间模型 (SSM) 中的 A 矩阵,使用 softplus 确保 A 矩阵为负数A = -torch.exp(self.A_log.float())  # (d_inner, d_state)# ====状态空间模型具体步骤 (SSM Step)====if selective_state_update is None:  # 如果没有使用选择性状态更新函数# 离散化 A 和 B 矩阵dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))  # 对 dt 进行 softplus 激活并加上偏置dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))  # 计算离散化后的 A 矩阵dB = torch.einsum("bd,bn->bdn", dt, B)  # 计算离散化后的 B 矩阵# 更新状态空间模型的状态ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)# 计算输出 yy = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)  # 状态乘以 C 得到输出y = y + self.D.to(dtype) * x  # 加上 D * x 的跳跃连接# ======= mamba块右边的激活叠加 ========y = y * self.act(z)  # 与 z 经过激活函数的结果相乘 (B, D)else:  # 如果使用了选择性状态更新函数y = selective_state_update(  # 调用加速的选择性状态更新函数ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True)# =====================  SSM 部分 ↑ =========================# 将输出通过线性投影层out = self.out_proj(y)  # (B, D)return out.unsqueeze(1), conv_state, ssm_state  # 返回输出 (B, 1, D) 以及更新后的状态>)

此外,辅助函数 allocate_inference_cache_get_states_from_cache 用于在推理过程中管理卷积状态和状态空间模型的状态缓存。

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):"""为推理过程分配卷积和状态空间模型的状态缓存。batch_size: 批次大小max_seqlen: 最大序列长度dtype: 数据类型 (可选)"""device = self.out_proj.weight.device  # 获取设备 (CPU 或 GPU)conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype  # 如果未指定 dtype,则使用卷积权重的 dtype# 初始化卷积状态,形状为 (B, D, W),其中 W 是卷积窗口大小conv_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype)ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype  # 如果未指定 dtype,则使用 dt_proj 权重的 dtype# 初始化状态空间模型的状态,形状为 (B, D, S),其中 S 是状态维度ssm_state = torch.zeros(batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype)return conv_state, ssm_state  # 返回分配的状态def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):"""从缓存中获取卷积和状态空间模型的状态,若缓存不存在则初始化状态。inference_params: 推理参数batch_size: 批次大小initialize_states: 是否初始化状态"""assert self.layer_idx is not None  # 确保层索引存在if self.layer_idx not in inference_params.key_value_memory_dict:  # 如果缓存中没有当前层的状态batch_shape = (batch_size,)# 初始化卷积状态conv_state = torch.zeros(batch_size,self.d_model * self.expand,self.d_conv,device=self.conv1d.weight.device,dtype=self.conv1d.weight.dtype,)# 初始化状态空间模型的状态ssm_state = torch.zeros(batch_size,self.d_model * self.expand,self.d_state,device=self.dt_proj.weight.device,dtype=self.dt_proj.weight.dtype,# dtype=torch.float32,)# 将初始化的状态存入缓存字典inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)else:  # 如果缓存中已有状态conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]# TODO: What if batch size changes between generation, and we reuse the same states?if initialize_states:  # 如果需要初始化状态conv_state.zero_()  # 将卷积状态清零ssm_state.zero_()  # 将状态空间模型的状态清零return conv_state, ssm_state  # 返回卷积和状态空间模型的状态

2.1.2 Block类

用于构建包含 Mamba 模块的更大网络。封装了 LayerNorm/RMSNorm 和残差连接的简单模块,以及Mamba的类。这个 Block 与传统的 Transformer 预归一化块稍有不同。
标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。
而这里的顺序是: Add -> LN -> Mixer ,并且返回的是 hidden_states 和 residual。
这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。
需要提供残差,除非是第一个块(第一个块不需要残差输入)。

......class Block(nn.Module): def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False):"""这是一个封装了 LayerNorm/RMSNorm 和残差连接的简单模块,包含了一个 mixer(Mamba) 类。这个 Block 与传统的 Transformer 预归一化块稍有不同。标准的 Transformer block 顺序为: LN -> MHA/MLP -> Add。而这里的顺序是: Add -> LN -> Mixer,并且返回的是 hidden_states 和 residual。这种设计主要是出于性能考虑,因为可以将残差和 LayerNorm 操作融合处理。需要提供残差,除非是第一个块(第一个块不需要残差输入)。"""super().__init__()  # 调用父类的初始化方法,确保 nn.Module 的属性被正确初始化self.residual_in_fp32 = residual_in_fp32  # 是否在 fp32 精度下处理残差self.fused_add_norm = fused_add_norm  # 是否启用残差加法和归一化融合self.mixer = mixer_cls(dim)  # 初始化 mixer 模块,该模块用于处理输入数据的混合self.norm = norm_cls(dim)  # 初始化归一化层,默认为 nn.LayerNorm# 如果启用了融合归一化,需要检查 RMSNorm 是否正确导入,并验证 self.norm 是 LayerNorm 或 RMSNormif self.fused_add_norm:assert RMSNorm is not None, "RMSNorm 导入失败"  # 确保 RMSNorm 导入成功assert isinstance(  # 确保 norm 是 LayerNorm 或 RMSNorm 类型self.norm, (nn.LayerNorm, RMSNorm)), "fused_add_norm 模式下仅支持 LayerNorm 和 RMSNorm"def forward(  # 定义前向传播方法self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):"""将输入传递给编码层的前向传播方法。参数:hidden_states: 输入序列 (必须提供)。residual: 输入的残差,如果 residual 为 None,hidden_states 直接作为残差使用。"""# 如果不启用 fused_add_norm,则进行标准的 LayerNorm 操作if not self.fused_add_norm:# 计算残差: 如果传入了 residual,将其加到 hidden_states 上,否则将 hidden_states 作为 residualresidual = (hidden_states + residual) if residual is not None else hidden_states# 将 residual 经过 norm 归一化hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))# 如果 residual_in_fp32 为真,将 residual 转换为 float32 处理if self.residual_in_fp32:residual = residual.to(torch.float32)else:  # 启用了 fused_add_norm 时,使用融合的归一化和加法操作fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn# 使用 fused_add_norm_fn 处理 hidden_states 和 residualhidden_states, residual = fused_add_norm_fn(hidden_states,self.norm.weight,self.norm.bias,residual=residual,prenorm=True,  # 使用预归一化residual_in_fp32=self.residual_in_fp32,  # 在 fp32 中计算残差eps=self.norm.eps,  # 归一化时使用的 epsilon 值,防止除零错误)# 将归一化后的 hidden_states 传递给 mixer 模块进行进一步处理hidden_states = self.mixer(hidden_states, inference_params=inference_params)# 返回处理后的 hidden_states 和 residualreturn hidden_states, residualdef allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):# 为推理分配缓存空间,用于存储中间状态,调用 mixer 的同名方法return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

2.2 官方完整实例

mixer_seq_simple.py 是官方的一个完整实例。
一个比较关键的函数是create_block函数,相当于使用mamba的一个接口。
create_block函数中包含了具体应该如何使用mamba_simple.py中的Block类

...
def create_block(d_model,ssm_cfg=None,norm_epsilon=1e-5,rms_norm=False,residual_in_fp32=False,fused_add_norm=False,layer_idx=None,device=None,dtype=None,
):if ssm_cfg is None:ssm_cfg = {}factory_kwargs = {"device": device, "dtype": dtype}# 创建Mamba混合器类的偏函数,允许传递额外的配置和工厂关键字参数mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)# 根据是否使用RMSNorm,创建归一化层类的偏函数,并设置epsilon和其他工厂关键字参数norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)# 创建Block实例,使用前面定义的混合器和归一化类block = Block(d_model,mixer_cls,norm_cls=norm_cls,fused_add_norm=fused_add_norm,residual_in_fp32=residual_in_fp32,)# 设置层索引,便于在模型中跟踪和配置此块block.layer_idx = layer_idxreturn block
...

(未完待续。。。)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1544644.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

OLED(2)驱动篇

文章目录 1 概述2 代码简述2.1 OLED 对象2.2 OLEDProtocol 对象2.3 OLEDFont 对象 3 成果展示 1 概述 1&#xff09;代码仓库&#xff1a;这里尝试了两种面向对象的方式&#xff0c;不足之处敬请指正。 OOP 方式&#xff1a;https://gitee.com/luyaocf/demo-jlc_stm32f407_oop.…

Unity 设计模式 之 行为型模式-【命令模式】【责任链模式】

Unity 设计模式 之 行为型模式-【命令模式】【责任链模式】 目录 Unity 设计模式 之 行为型模式-【命令模式】【责任链模式】 一、简单介绍 二、命令模式&#xff08;Command Pattern&#xff09; 1、什么时候使用命令模式 2、使用命令模式的好处 3、使用时的注意事项 三…

FME学习笔记

读取数据 方法一&#xff1a;add reader 通过读模块来进行数据的读取 方法二&#xff1a;FeatureReader Parameters 通过转换器来进行数据的读取 可以通过空间范围进行筛选 在FME中&#xff0c;所有数据处理都要用到的&#xff0c;绝对的重点&#xff1a;转换器&#xff…

【Python】PyCharm: 强大的 Python 开发环境

⭕️宇宙起点 &#x1f4e2; 引言&#x1f3ac; 什么是 PyCharm&#xff1f;&#x1f528; PyCharm 的核心特性1. 智能代码编辑2. 调试和测试3. 项目和代码结构导航4. 集成 AI 助手5. 远程开发6. 集成数据库7. 科学工具8. 版本控制集成9. Web 开发 &#x1f4e6; 安装 PyCharm&…

黑马智数Day4-1

新增月卡 配置路由完成跳转 {path: /cardAdd,component: () > import(/views/car/car-card/add-card) }<el-button type"primary" click"$router.push(/cardAdd)">添加月卡</el-button> 车辆信息表单验证 <el-form :model"carInf…

Bug:ThreadPoolTaskScheduler搭配CronTask完成定时任务,关闭scheduler后CronTask任务仍然执行?

【问题】执行下面代码后&#xff0c;关闭ThreadPoolTaskScheduler&#xff0c;CronTask仍然继续执行。 Configuration public class config {Beanpublic String getString() throws InterruptedException {Runnable runnable () -> {try {System.out.println("hello r…

《程序猿之设计模式实战 · 适配器模式》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

【后端开发】JavaEE初阶—线程安全问题与加锁原理(超详解)

前言&#xff1a; &#x1f308;上期博客&#xff1a;【后端开发】JavaEE初阶—Theard类及常见方法—线程的操作&#xff08;超详解&#xff09;-CSDN博客 &#x1f308;感兴趣的小伙伴看一看小编主页&#xff1a;GGBondlctrl-CSDN博客 &#x1f308;小编会在后端开发的学习中不…

关于javascript中防抖和节流的使用详解

防抖&#xff08;Debounce&#xff09;和节流&#xff08;Throttle&#xff09;是两种常见的优化技巧&#xff0c;通常用于控制函数在短时间内频繁触发的场景&#xff0c;尤其是在处理用户输入、滚动、窗口大小调整等事件时。它们的主要目的是减少不必要的函数调用&#xff0c;…

超详细超实用!!!AI编程之cursor编写设计模式开闭原则实例(四)

云风网 云风笔记 云风知识库 一、设计模式开闭原则定义 当应用的需求改变时&#xff0c;在不修改软件实体&#xff08;项目模块、类、接口方法&#xff09;的源代码或者二进制代码的前提下&#xff0c;可以扩展模块的功能&#xff0c;使其满足新的需求。即软件实体应当对扩展开…

【Linux】nginx连接前端项目

文章目录 一、项目编译1.编译文件2.dist文件 二、Linux nginx配置三、启动nginx 一、项目编译 1.编译文件 2.dist文件 二、Linux nginx配置 在Xshell软件中&#xff0c;点击CtrlAltF进入文件传输找到地址&#xff1a;/usr/local/nginx/html将dist文件传入 找到nginx.conf&…

git add成功后忘记commit的文件丢了?

本文目标&#xff1a;开发人员&#xff0c;在了解git fsck命令用法的条件下&#xff0c;进行git add成功但由于误操作导致丢失的文件找回&#xff0c;达到找回丢失文件的程度。 文章目录 1 痛点2 解决方案3 总结/练习 1 痛点 开发过程中&#xff0c;分支太多&#xff08;基线分…

CREO教程——2 绘制标准图纸

CREO教程——2 绘制标准图纸 说明&#xff1a;继承第一章设置好的配置文件&#xff0c;这一章进行学习分享如何定制自己的图纸图框&#xff0c;参考国家标准距&#xff0c;定制属于设计师或单位的通用图框。 1.设置工作目录 1.1设置工作目录 1.打开软件设置工作目录&#x…

u盘格式化怎么恢复数据?四款工具来救急!

工作中真的没少碰到过那些让人头疼的数据丢失问题&#xff0c;特别是U盘里的宝贝数据一不小心就“蒸发”了&#xff0c;简直让人欲哭无泪。不过别担心&#xff0c;我作为一个数据恢复的新手小白&#xff0c;最近可是亲测了几款超给力的数据恢复软件&#xff0c;今天就来跟大家分…

19c-TNS-12541: TNS:no listener

有套19c单机&#xff0c;没应用任何的补丁&#xff0c;使用lsnrctl status查看监听是异常的&#xff0c;但是lsnrctl start发现监听已运行&#xff0c;当前业务连接都正常&#xff0c; orcl:/home/oracledb> lsnrctl status LSNRCTL for Linux: Version 19.0.0.0.0 - Pro…

打造灵活DateTimePicker日期时间选择器组件:轻松实现时间的独立清除功能

element ui中日期和时间选择器&#xff08;DateTimePicker&#xff09;是一个常见且重要的组件。它允许用户轻松地选择日期和时间&#xff0c;极大地提升了用户体验。然而&#xff0c;在某些场景下&#xff0c;用户可能需要更细粒度的控制&#xff0c;例如单独清除已选择的时间…

下载与安装|Inventor 2025百度云资源分享附教程

如大家所了解的&#xff0c;Inventor是一款专业的三维可视化实体建模软件&#xff0c;主要用于各类二维机械制图、三维制图的设计和开发等操作&#xff0c;可以广泛地应用于零件设计、钣金设计、装配设计等领域。 不同领域的应用证明了Inventor具有强大的兼容性&#xff0c;基…

监控易监测对象及指标之:全面监控Oracle ODBC数据库

在数字化时代&#xff0c;数据库作为存储和管理企业核心数据的基石&#xff0c;其稳定性和性能直接关系到业务的连续性和效率。Oracle数据库以其强大的功能和稳定性&#xff0c;广泛应用于各行各业。为了确保Oracle数据库的稳定运行和高效性能&#xff0c;对其进行全面监控显得…

备战软考Day04-计算机网络

1、计算机网络的分类 2、七层网络体系结构 3、网络的设备与标准 4、TCP/IP协议族 TCP/IP作为Internet的核心协议&#xff0c;被广泛应用于局域网和广域网中&#xff0c;目前已成为事实上的国际标准 1、TCP/IP分层模型 TCP/IP协议是Internet的基础和核心&#xff0c;和OSI参考…

git命令将已经commit的代码push到其他分支

文章目录 一&#xff1a;对于多分支的代码库&#xff0c;将提交记录从一个分支转移到另一个分支是常见需求方法1&#xff1a;撤销commit操作方法2&#xff1a;实用命令git cherry-pick 来移动commit 二、不小心revert导致代码消失的问题 一&#xff1a;对于多分支的代码库&…