LLM各层参数详细分析(以LLaMA为例)

网上大多分析LLM参数的文章都比较粗粒度,对于LLM的精确部署不太友好,在这里记录一下分析LLM参数的过程。


首先看QKV。先上transformer原文
在这里插入图片描述
也就是说,当h(heads) = 1时,在默认情况下, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV都是2维方阵,方阵维度是 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel.

结合llama源码 (https://github.com/facebookresearch/llama/blob/main/llama/model.py)

class ModelArgs:dim: int = 4096n_layers: int = 32n_heads: int = 32n_kv_heads: Optional[int] = Nonevocab_size: int = -1  # defined later by tokenizermultiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2ffn_dim_multiplier: Optional[float] = Nonenorm_eps: float = 1e-5max_batch_size: int = 32max_seq_len: int = 2048
# ...class Attention(nn.Module):"""Multi-head attention module."""def __init__(self, args: ModelArgs):"""Initialize the Attention module.Args:args (ModelArgs): Model configuration parameters.Attributes:n_kv_heads (int): Number of key and value heads.n_local_heads (int): Number of local query heads.n_local_kv_heads (int): Number of local key and value heads.n_rep (int): Number of repetitions for local heads.head_dim (int): Dimension size of each attention head.wq (ColumnParallelLinear): Linear transformation for queries.wk (ColumnParallelLinear): Linear transformation for keys.wv (ColumnParallelLinear): Linear transformation for values.wo (RowParallelLinear): Linear transformation for output.cache_k (torch.Tensor): Cached keys for attention.cache_v (torch.Tensor): Cached values for attention."""super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()self.n_local_heads = args.n_heads // model_parallel_sizeself.n_local_kv_heads = self.n_kv_heads // model_parallel_sizeself.n_rep = self.n_local_heads // self.n_local_kv_headsself.head_dim = args.dim // args.n_heads

计算出
self.n_kv_heads = h = 32
self.head_dim = 4096/32=128
所以 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV 大小都为(4096, 128).(在未拆分前 W Q W^Q WQ W K W^K WK W V W^V WV都是 ( d i m , d i m ) = ( 4096 , 4096 ) (dim, dim) = (4096,4096) (dim,dim)=(4096,4096)大小)

Q , K , V Q,K,V Q,K,V的大小都是 ( n c t x , d i m ) = ( 2048 , 4096 ) (n_{ctx}, dim) = (2048,4096) (nctx,dim)=(2048,4096)在多头公式里。在self-attention里,其实他们都是同一个值:输入X),所以 Q × W i Q Q×W_i^Q Q×WiQ K × W i K K×W_i^K K×WiK Q × W i Q Q×W_i^Q Q×WiQ 都是 ( n c t x , d k ) = ( 2048 , 128 ) (n_{ctx}, d_k)=(2048,128) (nctx,dk)=(2048,128)。带入原文attention公式后,大小为(2048, 128)不变。Attention不改变大小(在默认 d k = d v d_k=d_v dk=dv情况下)。
在这里插入图片描述

经过Cancat,分开的头又合并,大小变为(2048, 4096)矩阵,经过 W O W^O WO大小是(4096,4096))全连接,还是(2048, 4096)矩阵。


然后看Feed forward.根据源码,

class FeedForward(nn.Module):def __init__(self,dim: int,hidden_dim: int,multiple_of: int,ffn_dim_multiplier: Optional[float],):"""Initialize the FeedForward module.Args:dim (int): Input dimension.hidden_dim (int): Hidden dimension of the feedforward layer.multiple_of (int): Value to ensure hidden dimension is a multiple of this value.ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.Attributes:w1 (ColumnParallelLinear): Linear transformation for the first layer.w2 (RowParallelLinear): Linear transformation for the second layer.w3 (ColumnParallelLinear): Linear transformation for the third layer."""super().__init__()hidden_dim = int(2 * hidden_dim / 3)# custom dim factor multiplierif ffn_dim_multiplier is not None:hidden_dim = int(ffn_dim_multiplier * hidden_dim)hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))

multiattention layer过后,经过加法和normlayer(RMS norm),进入feed_forward前馈网络。注意这里的前馈网络其中一个维度会有8/3≈2.7的放缩,然后multiple_of又保证必须是256的倍数,所以这里算出来hidden_dim是256的倍数中与8/3*4096最接近的,是11008。以这里的w1,w3大小为(4096,11008),w2大小为(11008,4096). 输出结果大小

整个decode layer计算如图所示,
在这里插入图片描述

来源:https://github.com/microsoft/Llama-2-Onnx/blob/main/Images/DecoderLayer.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/142805.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

电商项目高级篇-02 elasticsearch-下

电商项目高级篇-02 elasticsearch-下 4.2、QueryDSL返回指定字段 4.2、QueryDSL 返回指定字段 返回单个字段 GET bank/_search {"query": {"match_all": {}}, "sort": [{"balance": {"order": "desc"}}], &quo…

详解yolov1理论 代码

目标检测要解决的3大问题: 1、有没有? 图片中是否有要检测的物体?(检测物体,判定前景背景) 2、是什么? 这些物体分别是什么?(检测到的物体是什么) 3、在…

【RV1103】RTL8723bs (SD卡形状模块)驱动开发

文章目录 前言硬件分析Luckfox Pico的SD卡接口硬件原理图LicheePi zero WiFiBT模块总结 正文Kernel WiFi驱动支持Kernel 设备树支持修改一:修改二: SDK全局配置支持 wifi全局编译脚本支持编译逻辑拷贝rtl8723bs的固件到文件系统的固定目录里面去 上电后手…

Learn Prompt- Midjourney案例:Logo设计

Logo设计是一个充满挑战的任务,因为Logo是品牌重要价值的浓缩。 快速开始​ 直接使用logo design for...来获取灵感。 备注 图像中生成文字在Midjourney中的效果还不是很好,但你可以用Canva编辑图片并替换自己的文字。 在提示中使用那些擅长你所寻找的…

Spring整合RabbitMQ——生产者(利用配置类)

1.生产者配置步骤 2.引入依赖 3.编写配置 配置RabbitMQ的基本信息,用来创建连接工厂的 编写启动类 编写配置类 4. 编写测试类

idea没有maven工具栏解决方法

背景:接手的一些旧项目,有pom文件,但是用idea打开的时候,没有认为是maven文件,所以没有maven工具栏,不能进行重新加载pom文件中的依赖。 解决方法:选中pom.xml文件,右键 选择添加为…

Stm32_标准库_呼吸灯_按键控制

Stm32按键和输出差不多 PA1为LED供给正电,PB5放置按键,按键一端接PB5,另一端接负极 void Key_Init(void){RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOB, ENABLE); //APB2总线连接着GPIOBGPIO_InitStructur.GPIO_Mode GPIO_Mode_IPU;GPIO_InitStructur.…

24. 图论 - 图的表示种类

Hi,你好。我是茶桁。 之前的一节课中,我们了解了图的来由和构成,简单的理解了一下图的一些相关概念。那么这节课,我们要了解一下图的表示,种类。相应的,我们中间需要穿插一些新的知识点用于更好的去理解图,比如说邻接矩阵。 图的表示 我们一般用什么样的形式来表示图…

http基础教程(超详细)

HTTP HTTP 一 、基础概念 请求和响应报文URL 二、HTTP 方法 GETHEADPOSTPUTPATCHDELETEOPTIONSCONNECTTRACE 三、HTTP 状态码 1XX 信息2XX 成功3XX 重定向4XX 客户端错误5XX 服务器错误 四、HTTP 首部 通用首部字段请求首部字段响应首部字段实体首部字段 五、具体应用 连接管理…

使用 Python 函数callable和isinstance的意义

一、说明 在这篇博客中,我们将探讨两个python函数:1 callable 中的函数及其有趣的应用程序。该callable函数用于检查对象是否可调用,这意味着它可以作为函数调用。2 isinstance这个内置函数允许我们比较两种不同的数据类型并确定它们是否相…

Flink--6、输出算子(连接到外部系统、文件、kafka、MySQL、自定义Sink)

星光下的赶路人star的个人主页 世间真正温煦的春色,都熨帖着大地,潜伏在深谷 文章目录 1、输出算子(Sink)1.1 连接到外部系统1.2 输出到文件1.3 输出到Kafka1.4 输出到MySQL(JDBC)1.4 自定义Sink输出 1、输…

[密码学入门]仿射密码(Affine)

加密算法y(axb)mod N 解密算法x*(y-b)mod N(此处的为a关于N的乘法逆元,不是幂的概念) 如何求,涉及的知识挺多,还没想好怎么写,丢番图方程,贝祖定理(又译裴蜀定理),扩展欧…

IO流————

一、字符流 前面我们学习了字节流,使用字节流可以读取文件中的字节数据。但是如果文件中有中文,使用字节流来读取,就有可能读到半个汉字的情况,这样会导致乱码。虽然使用读取全部字节的方法不会出现乱码,但是如果文件过大又不太合适。 所以Java专门为我们提供了另外一种…

opencv: 解决保存视频失败的问题

摘要:opencv能读取视频,但保存视频时报错。 一、首先要确保已经下载了openh264.dll文件,否则保存的视频无法打开,详细可以浏览这个:opencv:保存视频。 二、保存视频时出现一下问题: OpenCV:…

vue_Delete `␍`eslint(prettier/prettier)

Delete ␍eslint(prettier/prettier) 错误的解决方案 问题背景 在Windows笔记本上新拉完代码,在执行pre-commit时,出现如下错误: Delete ␍eslint(prettier/prettier)问题根源 罪魁祸首是git的一个配置属性:core.autocrlf 由于…

【空间-光谱联合注意网络:多时相遥感图像】

A Spatial–Spectral Joint Attention Network for Change Detection in Multispectral Imagery (一种用于多光谱图像变化检测的空间-光谱联合注意网络) 变化检测是通过比较双时相图像来确定和评估变化,这是遥感领域的一项具有挑战性的任务…

Oracle VM VirtualBox安装并下载安装CentOS7

Oracle VM VirtualBox安装并下载安装CentOS7 Oracle VM VirtualBox下载CentOS创建虚拟机 Oracle VM VirtualBox VM下载链接 https://www.oracle.com/cn/virtualization/virtualbox/ 点击链接直接下载就行,下载完默认安装或者更改一下安装目录。 下载CentOS http://…

快速排序与冒泡排序以及代码

快速排序 快速排序(Quicksort)是一种常用的排序算法,它基于分治的思想。 时间复杂度:O(nlogn) 空间复杂度:O(logn) 快速排序的基本思想如下: 选择一个元素…

vue3 + mark.js | 实现文字标注功能

页面效果 具体实现 新增 1、监听鼠标抬起事件,通过window.getSelection()方法获取鼠标用户选择的文本范围或光标的当前位置。2、通过 选中的文字长度是否大于0或window.getSelection().isCollapsed (返回一个布尔值用于描述选区的起始点和终止点是否位于一个位置&…

设计模式 - 代理模式

目录 一. 前言 二. 实现 三. 静态代理和动态代理 一. 前言 代理模式(Proxy Pattern),为某个对象提供一种代理以控制对对象的访问。即客户端可通过代理对象间接访问目标对象,同时可限制、增强、修改目标对象的一些特性。访问者不…