基于Mamba2的文本生成实战
深入探索Mamba模型架构与应用 - 商品搜索 - 京东
DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东
本节将演示一个使用Mamba2模型完成文本生成任务的示例。在这个过程中,我们将充分利用已有的数据集来训练和优化Mamba2模型,以实现高质量的文本生成效果。
12.2.1 文本生成Mamba2模型的完整实现
类似于前面使用Mamba完成文本生成任务,我们这里将使用Mamba2作为主干网格设计文本生成模型。在具体应用上,使用Mamba2模型直接替代原有的Mamba模型即可。完整代码如下:
class Mamba2LMHeadModel(nn.Module):def _ _init_ _(self, args: Mamba2Config, device: Device = None):super()._ _init_ _()self.args = argsself.device = deviceself.backbone = nn.ModuleDict(dict(embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),layers=nn.ModuleList([nn.ModuleDict(dict(mixer=Mamba2(args, device=device),norm=RMSNorm(args.d_model, device=device),))for _ in range(args.n_layer)]),norm_f=RMSNorm(args.d_model, device=device),))self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False, device=device)self.lm_head.weight = self.backbone.embedding.weightdef forward(self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None) -> tuple[LongTensor, list[InferenceCache]]:seqlen = input_ids.shape[1]if h is None:h = [None for _ in range(self.args.n_layer)]x = self.backbone.embedding(input_ids)for i, layer in enumerate(self.backbone.layers):y, h[i] = layer.mixer(layer.norm(x), h[i])x = y + xx = self.backbone.norm_f(x)logits = self.lm_head(x)return logits[:, :seqlen], cast(list[InferenceCache], h)
可以看到,这里的核心在于使用我们之前完成的Mamba2模型作为特征提取的主干网络,以抽取和计算特征。至于其他部分,如输出分类层和返回值,读者可以参考Mamba模型的实现。
12.2.2 基于Mamba2的文本生成
最后,我们将完成基于Mamba2的文本生成实战任务。对于这部分内容,读者可以参考我们在实现Mamba文本生成时所准备的训练框架。完整代码如下:
from model import Mamba2,Mamba2Config
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoaderdevice = "cuda"
mamba_cfg = Mamba2Config(d_model=384)
mamba_cfg.chunk_size = 4
model = mamba_model = Mamba2(mamba_cfg,device=device)
model.to(device)
save_path = "./saver/mamba_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)BATCH_SIZE = 192
seq_len = 64
import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.TextSamplerDataset(get_ data_emotion.token_list,seq_len=seq_len)
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 1200,eta_min=2e-7,last_epoch=-1)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(48):pbar = tqdm(train_loader,total=len(train_loader))for token_inp,token_tgt in pbar:token_inp = token_inp.to(device)token_tgt = token_tgt.to(device)logits = model(token_inp)loss = criterion(logits.view(-1, logits.size(-1)), token_tgt.view(-1))optimizer.zero_grad()loss.backward()optimizer.step()lr_scheduler.step() # 执行优化器pbar.set_description(f"epoch:{epoch +1}, train_loss:{loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")torch.save(model.state_dict(), save_path)
简单给一下代码,具体生成结果,读者可以自行验证。
本书节选自《深入探索Mamba模型架构与应用》,获出版社和作者授权发布。