Mamba环境配置教程【自用】

1. 新建一个Conda虚拟环境

conda create -n mamba python=3.10

在这里插入图片描述

2. 进入该环境

conda activate mamba

在这里插入图片描述

3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio
直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio
CSDN资源
在这里插入图片描述

下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可

pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl 
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl

4. 安装triton和transformers库

pip install triton==2.3.1
pip install transformers==4.43.3

5. 安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可

conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

6. 最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本
causal-conv1d —— 1.4.0
在这里插入图片描述
mamba-ssm —— 2.2.2
在这里插入图片描述
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装

pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

7. 安装好环境后,验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行

# Copyright (c) 2024, Tri Dao, Albert Gu.import math
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom einops import rearrange, repeattry:from causal_conv1d import causal_conv1d_fn
except ImportError:causal_conv1d_fn = Nonetry:from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:RMSNormGated, LayerNorm = None, Nonefrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combinedclass Mamba2Simple(nn.Module):def __init__(self,d_model,d_state=128,d_conv=4,conv_init=None,expand=2,headdim=64,ngroups=1,A_init_range=(1, 16),dt_min=0.001,dt_max=0.1,dt_init_floor=1e-4,dt_limit=(0.0, float("inf")),learnable_init_states=False,activation="swish",bias=False,conv_bias=True,# Fused kernel and sharding optionschunk_size=256,use_mem_eff_path=True,layer_idx=None,  # Absorb kwarg for general moduledevice=None,dtype=None,):factory_kwargs = {"device": device, "dtype": dtype}super().__init__()self.d_model = d_modelself.d_state = d_stateself.d_conv = d_convself.conv_init = conv_initself.expand = expandself.d_inner = self.expand * self.d_modelself.headdim = headdimself.ngroups = ngroupsassert self.d_inner % self.headdim == 0self.nheads = self.d_inner // self.headdimself.dt_limit = dt_limitself.learnable_init_states = learnable_init_statesself.activation = activationself.chunk_size = chunk_sizeself.use_mem_eff_path = use_mem_eff_pathself.layer_idx = layer_idx# Order: [z, x, B, C, dt]d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheadsself.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)conv_dim = self.d_inner + 2 * self.ngroups * self.d_stateself.conv1d = nn.Conv1d(in_channels=conv_dim,out_channels=conv_dim,bias=conv_bias,kernel_size=d_conv,groups=conv_dim,padding=d_conv - 1,**factory_kwargs,)if self.conv_init is not None:nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)# self.conv1d.weight._no_weight_decay = Trueif self.learnable_init_states:self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))self.init_states._no_weight_decay = Trueself.act = nn.SiLU()# Initialize log dt biasdt = torch.exp(torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min))dt = torch.clamp(dt, min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759inv_dt = dt + torch.log(-torch.expm1(-dt))self.dt_bias = nn.Parameter(inv_dt)# Just to be explicit. Without this we already don't put wd on dt_bias because of the check# name.endswith("bias") in param_grouping.pyself.dt_bias._no_weight_decay = True# A parameterassert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)A_log = torch.log(A).to(dtype=dtype)self.A_log = nn.Parameter(A_log)# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)self.A_log._no_weight_decay = True# D "skip" parameterself.D = nn.Parameter(torch.ones(self.nheads, device=device))self.D._no_weight_decay = True# Extra normalization layer right before output projectionassert RMSNormGated is not Noneself.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)def forward(self, u, seq_idx=None):"""u: (B, L, D)Returns: same shape as u"""batch, seqlen, dim = u.shapezxbcdt = self.in_proj(u)  # (B, L, d_in_proj)A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else Nonedt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)if self.use_mem_eff_path:# Fully fused pathout = mamba_split_conv1d_scan_combined(zxbcdt,rearrange(self.conv1d.weight, "d 1 w -> d w"),self.conv1d.bias,self.dt_bias,A,D=self.D,chunk_size=self.chunk_size,seq_idx=seq_idx,activation=self.activation,rmsnorm_weight=self.norm.weight,rmsnorm_eps=self.norm.eps,outproj_weight=self.out_proj.weight,outproj_bias=self.out_proj.bias,headdim=self.headdim,ngroups=self.ngroups,norm_before_gate=False,initial_states=initial_states,**dt_limit_kwargs,)else:z, xBC, dt = torch.split(zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1)dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)assert self.activation in ["silu", "swish"]# 1D Convolutionif causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2))  # (B, L, self.d_inner + 2 * ngroups * d_state)xBC = xBC[:, :seqlen, :]else:xBC = causal_conv1d_fn(x=xBC.transpose(1, 2),weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),bias=self.conv1d.bias,activation=self.activation,).transpose(1, 2)# Split into 3 main branches: X, B, C# These correspond to V, K, Q respectively in the SSM/attention dualityx, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)y = mamba_chunk_scan_combined(rearrange(x, "b l (h p) -> b l h p", p=self.headdim),dt,A,rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),chunk_size=self.chunk_size,D=self.D,z=None,seq_idx=seq_idx,initial_states=initial_states,**dt_limit_kwargs,)y = rearrange(y, "b l h p -> b l (h p)")# Multiply "gate" branch and apply extra normalization layery = self.norm(y, z)out = self.out_proj(y)return outif __name__ == '__main__':model = Mamba2Simple(256).cuda()inputs = torch.randn(2, 128, 256).cuda()pred = model(inputs)print(pred.size())      

在这里插入图片描述

参考文献

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1537426.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

面试官问:你如何看待加班?

面试官问:你如何看待加班? 面试官问:你如何看待加班?这类问题是比较常见的,出现频率相当高。有些同学看到这样的问题,就会断定这家公司估计是经常加班的,绝对的不能去!!…

如何制作一张动态壁纸? ------居然也和编程有关,简单总结一下下,因为我也很好奇哈哈(含实战代码和效果演示)

相关语言: html ,css,JavaScript,c# c#的开发难度高,本文以前端三剑客为例。 理论基础 制作一款类似于Wallpaper Engine上的动态壁纸是一个有趣的项目,它涉及到一些基本的编程知识以及图形设计技能。下…

C++_类和对象(中、下篇)—— const成员函数、取地址运算符的重载、深入构造函数、类型转换、static成员、友元

目录 三、类和对象(中) 6、取地址运算符重载 1、const成员函数 2、取地址运算符的重载 四、类和对象(下) 1、深入构造函数 2、类型转换 3、static成员 4、友元 三、类和对象(中) 6、取地址运算…

削峰+限流:秒杀场景下的高并发写请求解决方案

我是小米,一个喜欢分享技术的29岁程序员。如果你喜欢我的文章,欢迎关注我的微信公众号“软件求生”,获取更多技术干货! 哈喽,大家好!我是小米,一个29岁、活泼积极、热衷分享技术的码农。今天和大家聊一聊应对高并发的写请求这个主题,尤其是在大促、秒杀这种场景下,系统…

南京信息工程大学《2020年+2021年817自动控制原理真题》 (完整版)

本文内容,全部选自自动化考研联盟的:《25届南京信息工程大学817自控考研资料》的真题篇。后续会持续更新更多学校,更多年份的真题,记得关注哦~ 目录 2020年真题 2021年真题 Part1:20202021年完整版真题 2020年真题…

Unity3D下如何播放RTSP流?

技术背景 在Unity3D中直接播放RTSP(Real Time Streaming Protocol)流并不直接支持,因为Unity的内置多媒体组件(如AudioSource和VideoPlayer)主要设计用于处理本地文件或HTTP流,而不直接支持RTSP。所以&…

并查集的应用

目录 1.并查集的代码 2.union操作 3.find操作 4.图 写代码:定义一个并查集(用长度为n的数组实现) 基于上述定义,实现并查集的基本操作—— 并 Union 基于上述定义,实现并查集的基本操作—— 查 Find 自己设计一…

欧美游戏市场的差异

欧洲和美国的游戏市场虽然高度发达且利润丰厚,但表现出由文化偏好、消费者行为、监管环境和平台受欢迎程度塑造的独特特征。这些差异对于寻求为每个地区量身定制策略的游戏开发商和发行商来说非常重要。 文化偏好和游戏类型 美国:美国游戏市场倾向于青…

Java基础尚硅谷84-面向对象-package与import关键字的使用

曾国藩说,基础不牢,很难走得远。 所以时时回顾一下Java基础,打好地基,让自己走得更稳,更远。 今天这节课,学到对自己有点价值的东西是: 1 import 导入A包.*,可以使用A包下&#xff…

操作系统、数据库

操作系统 管道:读写先进先出,不能想读哪里就读哪里,想写哪里就哪里 内存 块表的理论支撑:局部性原理:1.时间局部性和空间局部性原理。 内存映射:

目标检测——flask后端YOLOv8检测视频,前端实时显示检测结果

前端代码&#xff1a; index.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>YOLOv8 Video S…

网络安全笔试练习题,据说10分钟内答对的都是高手!

《网安面试指南》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484339&idx1&sn356300f169de74e7a778b04bfbbbd0ab&chksmc0e47aeff793f3f9a5f7abcfa57695e8944e52bca2de2c7a3eb1aecb3c1e6b9cb6abe509d51f&scene21#wechat_redirect 《Java代码审…

CSP-J/S赛前知识点大全2:初赛纯靠记忆的知识点

-NOI的中文意思是&#xff08;全国青少年信息学奥林匹克竞赛&#xff09;。 -NOIP从&#xff08;2022&#xff09;年开始不支持Pascal语言。 -中国计算机学会&#xff08;CCF&#xff09;于&#xff08;1962&#xff09;年成立&#xff0c;于(1984)年创办全国青少年计算机程序设…

惬意享受阅读,优雅的微信公众号订阅方式,极空间部署『WeWe RSS』

惬意享受阅读&#xff0c;优雅的微信公众号订阅方式&#xff0c;极空间部署『WeWe RSS』 哈喽小伙伴们好&#xff0c;我是Stark-C~ 不知道大家平时是怎么阅读自己关注的公众号文章的&#xff0c;是不是基本就靠微信平自动提醒更新呢&#xff1f;如果是这样&#xff0c;那么我…

dubbo二

dubbo dubbo扩展加载流程 服务调用过程 线程派发模型 多版本控制 集群容错 策略对比 负载均衡及其实现

ICM20948 DMP代码详解(25)

接前一篇文章&#xff1a;ICM20948 DMP代码详解&#xff08;24&#xff09; 上一回讲到了inv_icm20948_load_firmware函数&#xff0c;对于大体功能进行了介绍&#xff0c;本回深入其具体实现代码细节。为了便于理解和回顾&#xff0c;再次贴出相关代码&#xff1a; //Setup Iv…

甲骨文创始人埃里森:人工智能终有一天会追踪你的一举一动

9月17日消息&#xff0c;据外电报道&#xff0c;甲骨文创始人拉里埃里森在甲骨文财务分析师会议上表示&#xff0c;他预计人工智能有一天将为大规模执法监控网络提供动力。“我们将进行监督。”他说。“每一位警察都将随时受到监督&#xff0c;如果有问题&#xff0c;人工智能会…

从0到一个漏洞几千块,走了这么久,走了这么远,当然还要继续走下去。

从0到一个漏洞几千块&#xff0c;走了这么久&#xff0c;走了这么远&#xff0c;当然还要继续走下去。

odb使用

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、创建学生类和班级类1.学生类2.班级类3.生成数据库支持代码 二、创建数据库对象&#xff0c;对数据库进行操作1.构建连接池工厂配置对象2.构造数据库操作对象…

概率分布深度解析:PMF、PDF和CDF的技术指南

本文将深入探讨概率分布&#xff0c;详细阐述概率质量函数&#xff08;PMF&#xff09;、概率密度函数&#xff08;PDF&#xff09;和累积分布函数&#xff08;CDF&#xff09;这些核心概念&#xff0c;并通过实际示例进行说明。 在深入探讨PMF、PDF和CDF之前&#xff0c;有必…