当前位置: 首页 > news >正文

14.解码器的Mask

从入门AI到手写Transformer-14.解码器的Mask

  • 14.解码器的Mask
  • 代码

整理自视频 老袁不说话 。

14.解码器的Mask

解码器与编码器不同的地方如下图:
在这里插入图片描述
在之前的操作中,需要对填充字符 [pad] 进行掩码mask,消失填充字符的影响。对 Q K ⊤ d \frac{QK^\top}{\sqrt{d}} d QK n − m n-m nm行或 n − m n-m nm列进行mask,修改为 − 1 0 6 -10^{6} 106
而在编码器端,语义信息都是已知的, Q K ⊤ {QK^\top} QK计算时, [ x 1 x 2 ⋮ x n ] [ x 1 x 2 ⋯ x n ] = [ x 11 x 12 ⋯ x 1 n x 21 x 22 … x 2 n ⋮ x n 1 x n 2 … x n n ] \begin{bmatrix}x_{1}\\x_{2}\\\vdots\\{x_{n}}\end{bmatrix}\begin{bmatrix}x_{1}&x_{2}&\cdots&x_{n}\end{bmatrix}=\begin{bmatrix}x_{11}&x_{12}&\cdots&x_{1n}\\x_{21}&x_{22}&\dots&x_{2n} \\\vdots\\{x_{n1}}&x_{n2}&\dots&x_{nn}\end{bmatrix} x1x2xn[x1x2xn]=x11x21xn1x12x22xn2x1nx2nxnn
在解码器端,已知信息不足,只能计算到当前词语, [ x 1 x 2 ⋮ x n ] [ x 1 x 2 ⋯ x n ] = [ x 11 x 21 x 22 ⋮ x n 1 x n 2 … x n n ] \begin{bmatrix}x_{1}\\x_{2}\\\vdots\\{x_{n}}\end{bmatrix}\begin{bmatrix}x_{1}&x_{2}&\cdots&x_{n}\end{bmatrix}=\begin{bmatrix}x_{11}\\x_{21}&x_{22}\\\vdots\\{x_{n1}}&x_{n2}&\dots&x_{nn}\end{bmatrix} x1x2xn[x1x2xn]=x11x21xn1x22xn2xnn,是一个下三角矩阵,需要对原始矩阵的上三角矩阵掩码。

编码实现中,替换填充字符的掩码时,假设 b = 3 , n = 5 b=3,n=5 b=3,n=5,批量为3,句子长度为5,每个句子有效长度分别是 [ 4 , 2 , 3 ] [4,2,3] [4,2,3],则将比较矩阵设置为 [ 4 , 4 , 4 , 4 , 4 , 2 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 , 3 ] [4,4,4,4,4,2,2,2,2,2,3,3,3,3,3] [4,4,4,4,4,2,2,2,2,2,3,3,3,3,3],会将右侧字符判别为True,替换掉,详见这个链接
在这里插入图片描述
而替换上三角需要将矩阵设置为 [ 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 ] [1,2,3,4,5,1,2,3,4,5,1,2,3,4,5] [1,2,3,4,5,1,2,3,4,5,1,2,3,4,5],比较后得到上三角矩阵。
在这里插入图片描述

代码

import torch# 定义批次大小b和序列长度n
b = 1  # 批次大小(batch size)
n = 5  # 序列长度(sequence length)# 方法1:生成有效长度(每个样本使用相同长度)
# 随机生成b个1~n的数字,每个数字重复n次
valid_len1 = torch.randint(1, n+1, (b,)).repeat_interleave(n)
print(valid_len1)  # 示例输出(当b=3): tensor([2,2,2,2,2, 3,3,3,3,3, 4,4,4,4,4])# 方法2:生成有效长度(每个位置使用递增长度)
# 生成1~n的序列,并重复b次
valid_len2 = torch.arange(1, n+1).repeat(b)
print(valid_len2)  # 示例输出(当b=3): tensor([1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5])# 创建基础位置序列 [0,1,2,...,n-1]
a = torch.arange(n)
print(a)  # 输出: tensor([0, 1, 2, 3, 4])# 生成掩码(mask)
# mask1: 每个样本统一长度限制(广播比较)
mask1 = a[None, :] < valid_len1[:, None]  # 形状: (b*n, n)
# mask2: 每个位置独立长度限制(如填充位置)
mask2 = a[None, :] < valid_len2[:, None]  # 形状: (b*n, n)
print(mask1)
print(mask2)# 创建随机输入张量(形状: b*n × n)
X = torch.randn(b, n, n).reshape(b*n, n)# 应用mask2:将无效位置(mask=False)的值设为极小值
# 这通常在注意力机制中用于屏蔽无效位置(如填充位置)
X[~mask2] = -1e6  # ~表示逻辑取反
print(X)

输出结果

tensor([5, 5, 5, 5, 5])
tensor([1, 2, 3, 4, 5])
tensor([0, 1, 2, 3, 4])
tensor([[True, True, True, True, True],[True, True, True, True, True],[True, True, True, True, True],[True, True, True, True, True],[True, True, True, True, True]])
tensor([[ True, False, False, False, False],[ True,  True, False, False, False],[ True,  True,  True, False, False],[ True,  True,  True,  True, False],[ True,  True,  True,  True,  True]])
tensor([[-1.3461e+00, -1.0000e+06, -1.0000e+06, -1.0000e+06, -1.0000e+06],[-1.9000e+00, -2.1507e+00, -1.0000e+06, -1.0000e+06, -1.0000e+06],[ 1.3339e+00, -7.1440e-01,  4.6307e-01, -1.0000e+06, -1.0000e+06],[-1.1387e+00, -1.7289e+00, -1.1892e+00, -1.1285e+00, -1.0000e+06],[ 2.0739e-01, -1.3382e+00,  1.9721e-01,  2.8529e-01, -8.7918e-01]])
http://www.xdnf.cn/news/30277.html

相关文章:

  • End-to-End从混沌到秩序:基于LLM的Pipeline将非结构化数据转化为知识图谱
  • 从GET到POST:HTTP请求的攻防实战与CTF挑战解析
  • EF Core中动态加载关联的导航属性
  • 直线轴承的正确安装方式是什么?
  • LeetCode 热题 100_乘积最大子数组(88_152_中等_C++)(动态规划)
  • 虚拟现实(VR)技术在教育领域的创新应用
  • QML中的3D功能--入门开发
  • Chat2DB创始人姬朋飞:AI在 text2sql应用领域的实践
  • Java从入门到“放弃”(精通)之旅——数组的定义与使用⑥
  • 进程程序替换
  • 【橘子大模型】初探rag知识库的构建
  • Linux基础IO(八)之硬链接
  • 完整游戏排行榜系统实现
  • Redux Promise 中间件
  • C++ 数组 array ™实现动画效果全解析⚡YQW · Studio ⚡
  • Http基础
  • QML中的3D功能--自定义着色器开发
  • 硬件操作指南——ATK-MD0430 V20
  • 什么是超类实体和派生属性
  • JavaScript 变量语法扩展
  • C 语言联合与枚举:自定义类型的核心解析
  • Scade 语言词法介绍
  • 游戏引擎学习第235天:在 Windows 上初始化 OpenGL
  • 4N60-ASEMI开关电源与适配器专用4N60
  • 6.7 ChatGPT自动生成定时任务脚本:Python与Cron双方案实战指南
  • android测试依赖
  • Python番外——常用的包功能讲解和分类组合
  • GD32H7单片机使用segger_rtt,rtt-viewer看不到输出的问题,怎样解决?
  • 使用docker在manjaro linux系统上运行windows和ubuntu
  • 在统信UOS1060上新增备份到U盘