Transformer推理结构简析(Decoder + MHA)

一、Transformer 基本结构

Transformer由encoder和decoder组成,其中:

  • encoder主要负责理解(understanding) The encoder’s role is to generate a rich representation (embedding) of the input sequence, which the decoder can use if needed

  • decoder主要负责生成(generation) The decoder outputs tokens one by one, where the current output depends on the previous tokens. This process is called auto-regressive generation

基本结构如下:

encoder结构和decoder结构基本一致(除了mask),所以主要看decoder即可:

每个核心的Block包含:

  • Layer Norm
  • Multi headed attention
  • A skip connection
  • Second layer Norm
  • Feed Forward network
  • Another skip connection

看下llama decoder部分代码,摘自transformers/models/llama/modeling_llama.py,整个forward过程和上图过程一模一样, 只是layer_norm换成了LlamaRMSNorm:

# 省略了一些不重要的code
class LlamaDecoderLayer(nn.Module):def __init__(self, config: LlamaConfig, layer_idx: int):...def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_value: Optional[Tuple[torch.Tensor]] = None,output_attentions: Optional[bool] = False,use_cache: Optional[bool] = False,cache_position: Optional[torch.LongTensor] = None,**kwargs,) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:# hidden_states [bsz, q_len, hidden_size]residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attention 即MHAhidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,cache_position=cache_position,**kwargs,)hidden_states = residual + hidden_states# Fully Connecteresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesoutputs = (hidden_states,)return outputs

上述代码展示的是标准的decoder过程,几个关键输入:

  • hidden_states [batch_size, seq_len, embed_dim] seq_len表示输入长度

  • attention mask的size为(batch_size, 1, query_sequence_length, key_sequence_length) 注意力掩码,实际使用的时候,PyTorch 会自动广播这个掩码到注意力权重矩阵的形状 [bsz, num_heads, q_len, kv_seq_len]。

  • position_ids or position_embeddings,位置id或者已经提前计算好的位置embedding

上述最核心的结构是其调用的self.self_attn,即是Multi-headed attention


二、Multihead Attention

Multihead Attention,多头注意力,上述decoder过程中最核心的地方,同时也是算子优化发力的地方。要理解多头先从单头开始。

单个attention

即Scaled Dot-Product Attention,公式如下:

其中QKV的维度一致,比如这里都是(2,3):

那么QKV怎么得到的呢?通过输入embedding和WqWkWv相乘得到qkv,这里WqWkWv是可学习参数:

拆开合起来计算都是等价的,上述的X1和X2是拆开计算,但是组合起来为(2,4)维度,同样可以和WqWkWv进行矩阵乘

实际在decoder的计算中,会带入causal (or look-ahead) mask

Causal mask 是为了确保模型在解码时不会关注未来的 token,这对于生成任务是必不可少的。通过这个掩码,模型只会依赖已经生成的 token,确保解码过程中是自回归的。

多个attention

实际中一般都是多个attention,和单个attention区别不大,下图右侧是多个attention的结构:

自注意力在多个头部之间并行应用,最后将结果连接在一起,我们输入(2,4)维度的X,分别和不同头的WqWkWv进行矩阵乘法得到每个头对应的QKV,然后QKV算出Z,再将所有Z合并和Wo相乘得到维度和X一致的Z:

实际中需要学习的权重为每个头的WqWkWv,同时也需要一个Wo,看一下llama中的实际计算过程:

class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = configself.layer_idx = layer_idxself.attention_dropout = config.attention_dropoutself.hidden_size = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.hidden_size // self.num_headsself.num_key_value_heads = config.num_key_value_headsself.num_key_value_groups = self.num_heads // self.num_key_value_headsself.max_position_embeddings = config.max_position_embeddingsself.rope_theta = config.rope_thetaself.is_causal = True# 这行代码是一个检查条件,确保hidden_size能够被num_heads整除。# 在多头注意力(Multi-Head Attention, MHA)机制中,输入的hidden_size被分割成多个头,每个头处理输入的一个子集。# head_dim是每个头处理的维度大小,它由hidden_size除以num_heads得到。if (self.head_dim * self.num_heads) != self.hidden_size:raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"f" and `num_heads`: {self.num_heads}).")# 需要学习更新的四个权重 WqWkWvWoself.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)self._init_rope()def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_value: Optional[Cache] = None,output_attentions: bool = False,use_cache: bool = False,**kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:bsz, q_len, _ = hidden_states.size()query_states = self.q_proj(hidden_states)key_states = self.k_proj(hidden_states)value_states = self.v_proj(hidden_states)query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)kv_seq_len = key_states.shape[-2]if past_key_value is not None:if self.layer_idx is None:raise ValueError(f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ""for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ""with a layer index.")kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)if past_key_value is not None:cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE modelskey_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)key_states = repeat_kv(key_states, self.num_key_value_groups)value_states = repeat_kv(value_states, self.num_key_value_groups)attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)if attention_mask is not None:if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")attn_weights = attn_weights + attention_mask# upcast attention to fp32attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# (bsz, self.num_heads, q_len, self.head_dim)attn_output = torch.matmul(attn_weights, value_states)# (bsz, q_len, self.num_heads, self.head_dim)attn_output = attn_output.transpose(1, 2).contiguous()attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)attn_output = self.o_proj(attn_output)return attn_output, attn_weights, past_key_value

上述的计算过程可以总结为以下几个步骤,假设输入张量hidden_states的维度为[batch_size, seq_length, hidden_size]

在这里插入图片描述
在这里插入图片描述

这个过程实现了将输入通过多个注意力"头"并行处理的能力,每个"头"关注输入的不同部分,最终的输出是所有"头"输出的拼接,再经过一个线性变换。这种机制增强了模型的表达能力,使其能够从多个子空间同时捕获信息。

因为不用GQA,q_len 就是 seq_length 就是 kv_seq_len


三、MHA计算

实际的多头计算代码如下,这里是通过torch.matmul实现的:

 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)if attention_mask is not None:  # no matter the length, we just slice itcausal_mask = attention_mask[:, :, :, : key_states.shape[-2]]attn_weights = attn_weights + causal_maskposition_ids: Optional[torch.LongTensor] = None,# upcast attention to fp32attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)attn_output = torch.matmul(attn_weights, value_states)

上述代码中query_states和key_states的形状分别为[bsz, num_heads, q_len, head_dim]和[bsz, num_heads, kv_seq_len, head_dim]。matmul操作会自动在最后两个维度上进行矩阵乘法,并在前两个维度上进行广播。

应用注意力掩码attn_weights = attn_weights + causal_mask,causal_mask的形状可能是[1, 1, q_len, kv_seq_len]。PyTorch会自动将其广播到attn_weights的形状[bsz, num_heads, q_len, kv_seq_len]。

应用softmax和dropout,然后最后计算attn_output = torch.matmul(attn_weights, value_states),其中attn_weights的形状为[bsz, num_heads, q_len, kv_seq_len],value_states的形状为[bsz, num_heads, kv_seq_len, head_dim]。matmul操作会在最后两个维度上进行矩阵乘法,并在前两个维度上进行广播。这里attn_output的维度为bsz, num_heads, q_len, self.head_dim


四、torch.matmul

多维矩阵乘法,支持多维和broadcast,比较复杂:

  • 如果两个输入张量都是一维张量,执行的是点积操作,返回一个标量
  • 如果两个输入张量都是二维张量,执行的是矩阵乘法,返回一个新的二维矩阵,这个操作就是常见的
  • 如果第一个张量是一维张量,第二个张量是二维张量,则会在第一张量的维度前面添加一个1(扩展为2维),然后进行矩阵乘法,计算完后会移除添加的维度
  • 如果第一个张量是二维张量,第二个张量是一维张量,则执行的是矩阵-向量乘法,返回一个一维张量
  • 当两个输入张量中有一个是多维的(N > 2),会执行批量矩阵乘法。在这种情况下,非矩阵的维度(批量维度)会被广播(broadcasted)。如果一个张量是一维,会对其进行维度扩展和移除

我们这里的多维数据 matmul() 乘法,可以认为该乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。

比如,输入张量的形状为 (j×1×n×n) 和 (k×n×n) 时,会输出形状为 (j×k×n×n) 的张量。

具体点,假设两个输入的维度分别是input (1000×500×99×11), other (500×11×99)那么我们可以认为torch.matmul(input, other, out=None) 乘法首先是进行后两位矩阵乘法得到(99×11)×(11×99)⇒(99×99) ,然后分析两个参数的batch size分别是 (1000×500)和 500, 可以广播成为 (1000×500), 因此最终输出的维度是(1000×500×99×99)。

计算QK点积的时候:

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

query_states的形状为[bsz, num_heads, q_len, head_dim],key_states.transpose(2, 3)的形状为[bsz, num_heads, head_dim, kv_seq_len]。matmul操作会在最后两个维度上进行矩阵乘法,得到形状为[bsz, num_heads, q_len, kv_seq_len]的注意力权重。

计算注意力输出:

attn_output = torch.matmul(attn_weights, value_states)   

这里使用torch.matmul将注意力权重与值(value)相乘。attn_weights的形状为[bsz, num_heads, q_len, kv_seq_len],value_states的形状为[bsz, num_heads, kv_seq_len, head_dim]。matmul操作会在最后两个维度上进行矩阵乘法,得到形状为[bsz, num_heads, q_len, head_dim]的注意力输出。

因为我们不用GQA,q_len 就是 kv_seq_len


如何学习大模型

现在社会上大模型越来越普及了,已经有很多人都想往这里面扎,但是却找不到适合的方法去学习。

作为一名资深码农,初入大模型时也吃了很多亏,踩了无数坑。现在我想把我的经验和知识分享给你们,帮助你们学习AI大模型,能够解决你们学习中的困难。

我已将重要的AI大模型资料包括市面上AI大模型各大白皮书、AGI大模型系统学习路线、AI大模型视频教程、实战学习,等录播视频免费分享出来,需要的小伙伴可以扫取。

一、AGI大模型系统学习路线

很多人学习大模型的时候没有方向,东学一点西学一点,像只无头苍蝇乱撞,我下面分享的这个学习路线希望能够帮助到你们学习AI大模型。

在这里插入图片描述

二、AI大模型视频教程

在这里插入图片描述

三、AI大模型各大学习书籍

在这里插入图片描述

四、AI大模型各大场景实战案例

在这里插入图片描述

五、结束语

学习AI大模型是当前科技发展的趋势,它不仅能够为我们提供更多的机会和挑战,还能够让我们更好地理解和应用人工智能技术。通过学习AI大模型,我们可以深入了解深度学习、神经网络等核心概念,并将其应用于自然语言处理、计算机视觉、语音识别等领域。同时,掌握AI大模型还能够为我们的职业发展增添竞争力,成为未来技术领域的领导者。

再者,学习AI大模型也能为我们自己创造更多的价值,提供更多的岗位以及副业创收,让自己的生活更上一层楼。

因此,学习AI大模型是一项有前景且值得投入的时间和精力的重要选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147703.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

国内短剧cps系统和短剧(播放)系统的区别,附各源码部署教程

国内短剧项目主要分为两大形式:一种是做短剧播放平台,让用户付费观看;另一种是做短剧的分销,就是将他人的平台短剧推广,可做平台可入驻,拿分成。 首先来说一下短剧播放平台(短剧系统&#xff0…

828华为云征文|华为云服务器Flexus X 搭建BTC虚拟币质押投资理财系统(仅测试学习)

一、华为云服务器Flexus X 选购和介绍 强大性能,引领云服务新潮流 柔性算力,满足多样化需求 Flexus X实例的部署与管理过程也非常便捷。用户只需在华为云官网注册账号,选择适合的Flexus X实例规格,完成购买后即可开始部署。华为…

telnet ftp ssh 如何在交换设备上创建

telnet 测试 说明telnet 成功 这测试ftp 成功

深入理解MySQL InnoDB中的B+索引机制

目录 一、InnoDB中的B 树索引介绍 二、聚簇索引 (一)使用记录主键值的大小进行排序 页内记录排序 页之间的排序 目录项页的排序 (二)叶子节点存储完整的用户记录 数据即索引 自动创建 (三)聚簇索引…

【每日刷题】Day129

【每日刷题】Day129 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 105. 从前序与中序遍历序列构造二叉树 - 力扣(LeetCode) 2. LCR 154. 复杂…

足球预测模型理论:足球数据分析——XGBoost算法实战

简介:本文将探讨如何使用XGBoost算法进行足球数据分析,特别是足球运动员身价估计。我们将通过实例和生动的语言,解释XGBoost算法的原理和实际应用,帮助读者理解复杂的技术概念,并提供可操作的建议和解决问题的方法。 足…

Eclipse离线安装Tomcat插件

Eclipse离线安装Tomcat插件 最近的自己在对低版本的代码的进行维护补丁,不得不采用Eclipse 来进行跑项目,真的是折磨 其中遇到一个问题就是打开Eclipse的2021版,安装Tomcat的插件,发现好家伙,就是死活在线安装失败 (喵的,真的是让我抓耳挠腮!!哈哈哈) 无奈,只好采用离线安装,特…

实时语音识别技术实现

实时语音识别 1.环境2.完整代码3.效果4.可能的问题 实时语音识别 1.环境 python版本:3.11.9 2.完整代码 import sqlite3 import timefrom funasr import AutoModel import sounddevice as sd import numpy as np# 模型参数设置 chunk_size [0, 10, 5] encoder_c…

60.【C语言】内存函数(memset,memcmp函数)

3.memset函数(常用) *简单使用 memset:memory set cplusplus的介绍 点我跳转 翻译: 函数 memset void * memset ( void * ptr, int value, size_t num ); 填充内存块 将ptr指向的内存块的前num个字节设置为指定值(解释为无符号char)。 (指针ptr类型为…

短剧APP分销小视频联盟收益源码带版权激励视频无需自己上传短剧

功能介绍: 带2000多部短剧资源,有版权,无需自己更新短剧, 已对接广告联盟,解锁短剧观看激励视频,对接各大广告平台 带刷小视频功能,插入视频广告,获取广告收益, 带任…

力扣206.反转链表

力扣《反转链表》系列文章目录 刷题次序,由易到难,一次刷通!!! 题目题解206. 反转链表反转链表的全部 题解192. 反转链表 II反转链表的指定段 题解224. 两两交换链表中的节点两个一组反转链表 题解325. K 个一组翻转…

【C++掌中宝】缺省参数的全面解析

文章目录 前言1. 什么是缺省参数?2. 缺省参数的分类2.1 全缺省【备胎是如何使用的😅】2.1.1 疑难细究 2.2 半缺省2.2.1 错误用法示范2.2.2 正确用法示范2.2.3🔥实参缺省与形参缺省的混合辨析🔥 3. 缺省参数的规则和限制4. 规定必须…

Leetcode 1039. 多边形三角形剖分的最低得分 枚举型区间dp C++实现

问题:Leetcode 1039. 多边形三角形剖分的最低得分 你有一个凸的 n 边形,其每个顶点都有一个整数值。给定一个整数数组 values ,其中 values[i] 是第 i 个顶点的值(即 顺时针顺序 )。 假设将多边形 剖分 为 n - 2 个三…

邮件发送高级功能详解:HTML格式、附件添加与SSL/TLS加密连接

目录 一、邮件HTML格式设置 1.1 HTML邮件的优势 1.2 HTML邮件的编写 二、添加附件 2.1 附件的重要性 2.2 添加附件的代码示例 2.3 注意事项 三、使用SSL/TLS加密连接 3.1 SSL/TLS加密的重要性 3.2 SSL/TLS加密的工作原理 3.3 在邮件发送中启用SSL/TLS 3.3.1 邮件客…

力扣 LCR 020 回文子串 -Python

题目链接:LCR 020. 回文子串 - 力扣(LeetCode) 题目描述: 给定一个字符串 s ,请计算这个字符串中有多少个回文子字符串。 具有不同开始位置或结束位置的子串,即使是由相同的字符组成,也会被视…

OpenFeign 远程调用

目录 前言 OpenFeign 介绍 OpenFeign 的前⾝ Spring Cloud Feign 快速上⼿ 引⼊依赖 添加注解 编写 OpenFeign 的客户端 远程调⽤ OpenFeign 参数传递 传递单个参数 传递多个参数 传递对象 传递 JSON 最佳实践 Feign 继承⽅式 创建⼀个 Module 引⼊依赖 编写…

EasyExcel将数据库里面的数据生成excel文件

EasyExcel官方文档 1.在model模块导入依赖 <!-- 生成报表--> <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>4.0.3</version> </dependency> 2.修饰实体类 package…

四叉树碰撞代码

使用raylib 代码来源 https://github.com/seyhajin/flux-samples/blob/master/raylib/quadtree/quadtree.c 原来是视锥碰撞四叉树&#xff0c;经过一周开发变成碰撞检测四叉树可视化 后经过改写 绿色检测 灰色检测 //https://github.com/seyhajin/flux-samples/blob/mast…

【C++篇】走进C++标准模板库:STL的奥秘与编程效率提升之道

文章目录 C STL 初探&#xff1a;打开标准模板库的大门前言第一章: 什么是STL&#xff1f;1.1 标准模板库简介1.2 STL的历史背景1.3 STL的组成 第二章: STL的版本与演进2.1 不同的STL版本2.2 STL的影响与重要性 第三章: 为什么学习 STL&#xff1f;3.1 从手动编写到标准化解决方…

three.js 让阴影更黑更暗

r166 可以通过设置intensity属性来配置每个光源的阴影强度 light.shadow.intensity 3;或者 修改shader THREE.ShaderChunk["shadowmap_pars_fragment"]THREE.ShaderChunk["shadowmap_pars_fragment"].replace( "occlusion clamp( max( hard_sha…