【Transformer】transformer模型结构学习笔记

文章目录

      • 1. transformer架构
      • 2. transformer子层解析
      • 3. transformer注意力机制
      • 4. transformer部分释疑

 

图1 transformer模型架构
图2 transformer主要模块简介
图3 encoder-decoder示意图N=6
图4 encoder-decoder子层示意图

1. transformer架构

  • encoder-decoder框架是一种处理NLP或其他seq2seq转换任务中的常见框架, 机器翻译就是典型的seq2seq模型, 两个seq序列长度可以不相等

  • transformer也是encoder-decoder的总体架构, 如上图。transformer主要由4个部分组成:

    • 输入部分(输入输出嵌入与位置编码)
    • 多层编码器
    • 多层解码器
    • 以及输出部分(输出线性层与softmax)
  • 模块介绍

    • Input Embedding: 输入嵌入。将源文本中的词汇数字表示转换为向量表示,捕捉词汇间的关系
    • Positional Encoding: 位置编码。为输入序列的每个位置生成位置向量,以便模型能够理解序列中的位置信息
    • Output Embedding: 输出嵌入。将目标文本中的词汇数字表示转换为向量表示
    • Linear: 线性层。将decoder输出的向量转换为最终的输出维度
    • Softmax: softmax层。将线性层的输出转换为概率分布,以便进行最终的预测
    • encoder架构: encoder由6个相同的encoder层组成,每个层包括两个子层:一个多头自注意力层(multi-head self-attention)和一个逐位置的前馈神经网络(point-wise feed-forward network);每个子层后都会使用残差连接(residual connection)和层归一化(layer normalization)连接,即Add&Norm。如下图
    • decoder架构:decoder包含6个相同的decoder层,每层包含3个子层掩码自注意力层masked self-attention),encoder-decoder交叉注意力层逐位置的前馈神经网络。每个子层后都有残差连接层归一化操作,即Add&Norm。如下图

2. transformer子层解析

  • encoder和decoder的本质区别:self-attention的masked掩码机制
  • muitl-head进行masked的目的:在生成文本时,确保模型只依赖已知的信息,而不是未来的内容,对未来信息进行掩码处理,这样才能学会预测
  • multi-head的目的:让模型关注输入的不同部分或者不同信息,比如一个名词的修饰词,一个名词的分类,一个名词对象的情感、诗意等,从直观的到抽象的,捕获复杂的依赖关系
  • Add:残差连接。缓解梯度消失问题;网络输入x与网络输出F(x)相加,求导时相当于添加常数项1,缓解梯度消失问题
  • Norm:层归一化。在每个层上独立进行,使激活值具有相同的均值和方差,通常是0和1;在transformer中,Norm操作通常紧跟在Add之后,对残差连接结果进行归一化,以加速训练并稳定模型性能
  • 前馈网络:对输入进行非线性变换,提取更高级别的特征/信息
  • 逐位前馈神经网络:是一个简单的全连接神经网络,在模型中起到增加非线性和学习更复杂表示的作用。逐位的意思是逐个元素element或点进行独立且相同的操作,不是跨位置或跨元素来进行的。逐位前馈神经网络通常包括两个全连接层一个ReLU激活层,两个全连接层对应两个线性变换,第一个全连接层之后接ReLU激活函数引入非线性,使模型能够学习更复杂的表示。第一个全连接层通常对输入进行增维表示,第二个全连接层降维到模型输出所需的维度

3. transformer注意力机制

  • transformer的3种注意力层:在transformer架构中有3种不同的注意力层
    • self-attention layer自注意力层:编码器输入序列通过multi-head self-attention计算自注意力权重
    • casual attention layer因果自注意力层:解码器的单个序列通过masked multi-head self-attention计算自注意力权重
    • cross attention layer交叉注意力层:编码器-解码器两个序列通过multi-head cross attention进行注意力转移
  • 注意力机制的过程说明
  • 缩放点积注意力

上图是缩放点积注意力示意图,计算公式

其中,softmax内部是注意力分数,softmax整个是注意力权重,乘以缩放因子 1 d k \frac{1}{\sqrt{d_{k} } } dk 1是为了缓解可能的梯度消失问题(softmax值过大时), d k d_{k} dk是Q或者K的维度大小

  • 多头注意力机制

上图是多头注意力机制示意图,多个注意力头并行运行,每个头都会独立地计算注意力权重和输出,这里采用的是缩放点积注意力来计算;

然后将所有头的输出拼接concat起来得到最终的输出;

多头其实是为了提取不同维度的信息,捕获复杂的依赖关系,增强模型的表示能力;最后多个头结果进行拼接,避免单个计算的误差,即避免只关注单方面维度信息的误差

计算公式:

在transformer原文中,head_num = 8,d_k=d_v=64

  • 交叉注意力机制

    • 自注意力机制,QKV都来自同一序列,如下
    • 交叉注意力机制,输入来自两个不同的序列,一个序列用作查询Q(来自decoder states的queries),另一个序列提供键K和值V(来自encoder states的keys和values),实现跨序列的交互和注意力转移,如下
  • 因果注意力机制

    • 为了确保模型在生成序列时只依赖于之前的输入信息,而不会受到未来信息的影响。casual self-attention通过掩码未来位置来实现这一点;使模型在预测某个位置的输出时,只看到该位置及之前的输入。如下图所示
    • 其中掩码未来位置的原因通过下图说明:
    • 掩码机制通过下图说明,加一个很大的负数,softmax之后就是0,如下

4. transformer部分释疑

  • 问题1:transformer相对RNN能处理长序列数据, 同时能进行并行计算, LSTM相对RNN进行改进的, 解决长时依赖问题, 那么transformer相对于LSTM有什么优势
    • (1)LSTM在解决长时依赖仍有局限。LSTM依赖cell state来传递长时信息,限制了其全局信息捕获能力;而transformer的自注意力机制可以考虑任意两个位置之间的依赖关系,能更好的捕捉全局的、长距离的依赖信息
    • (2)transformer的可解释性更强:transformer计算每个位置与所有位置的依赖关系,使得模型的预测结果更易于解释,LSTM的解释性相对较弱
    • (3)并行计算能力:transformer不用像LSTM等待上一时间步的输出作为下一时间步的输入,可以实现完全并行的计算,更容易进行分布式计算和加速
    • (4)扩展性和灵活性:transformer结构相对灵活,可以轻松扩展到更大的数据集和更复杂的任务中

 

  • 问题2:同问题1, transformer通过怎样的设计能够实现并行计算的?

    • 参考这个图,可以并行计算一个位置和其他所有位置的依赖关系
  • 问题3:层归一化Norm和batch normalization的区别

    • 都是归一化,但层归一化不是批量归一化;
    • LN是对每个样本的每个层进行的归一化,即对每个样本的所有特征做归一化;
    • 而BN是对每个batch数据进行归一化,即对batch_size内的每个特征做归一化;
    • LN保留了不同特征之间的大小关系,抹平了不同样本之间的大小关系,所以LN更适合NLP领域任务;
    • 而BN保留了不同样本之间的大小关系,抹平了不同特征之间的大小关系,所以BN更适合于依赖不同样本之间关系的任务,如CV领域
    • LN可以缓解梯度消失问题、改善系统对缩放摆幅变化的鲁棒性、更适用于小样本数据情况
    • 而BN旨在提高模型的训练速度和稳定性,使模型学习效率更高,降低测试错误率和泛化误差

 

  • 问题4:encoder和decoder的本质区别self-attention是否masked,如何理解
    • encoder中每个元素都能管住整个序列中的所有其他元素,生成新的输出表示。处理整个输入序列,不需要掩码未来的信息
    • decoder在生成序列时,只能依赖已经生成的部分,而不能依赖未来的信息。masked处理的是输出序列,将未来位置的注意力权重设置为0,从而限制模型的关注点在已生成的序列上,实现了类似条件语言模型的功能
    • decoder和encoder交叉注意力层,decoder允许关注encoder的输出,从而融合encoder中的信息到生成过程

 

  • 问题5:transformer训练的过程参数有哪些,除了W_Q/K/V这几个参数矩阵以外
    • (1)嵌入维度:输入和输出嵌入的维度,词嵌入和位置编码的维度。比如词嵌入矩阵大小为词汇表大小如50000 * d_词嵌入向量的维度
    • (2)multi-head attention的num_heads:注意力头数,决定模型并行关注输入序列不同部分的能力,每个头都会产生一个独立的注意力权重矩阵。论文中num_heads = 8
    • (3)隐藏层层数:每个encoder层和decoder层都保持一致
    • (4)前馈神经网络隐藏层大小:神经元个数,通常比层数大很多,以便能学习复杂的特征表示
    • (5)encoder和decoder的层数:定义了模型中encoder和decoder各自包含的层数,论文中n_layers = 6,即6个encoder层和6个decoder层
    • (6)位置编码的维度:输入输出序列进入encoder/decoder层时都要进行位置编码,通常与嵌入维度相同,以便和嵌入向量直接相加
    • (7)训练参数:像学习率,选用的优化器,batch_size,epoches等
    • (8)正则化参数:如dropout rate随机失活的神经元比例防止过拟合,L2正则化等
    • (9)权重初始化方法:如随机初始化,Xavier初始化,He初始化等,合理的初始化能加快训练的过程尽快找到最优解

 

  • 问题6:QKV计算的过程,W矩阵都是可以训练的

  • 问题7:self-attention和(cross)attention的区别

    • self-attention设置source=target,即query=key=value,然后计算内部依赖关系

 

  • 问题8:预训练模型BERT和transformer是什么关系
    • BERT(Bidirectional Encoder Representations from Transformers)使用transformer的encoder结构来构建的,输入与transformer类似,包括token/segment/position embedding等,这些embedding将输入文本序列转换为模型可以理解的向量表示;
    • 在BERT中可以选择encoder层的数量,轻量级模型通常使用12层,重量级模型通常使用24层;transformer的自注意力机制使BERT能够关注双向上下文的信息

 

  • 问题9:transformer模型训练的时候采用了什么损失函数
    • transformer训练过程主要采用了交叉熵损失函数(负对数似然损失函数)来衡量模型预测的概率分布真实分布之间的差异,也可以采用KL散度;
    • 并且可以计算向量空间距离MSE,即两组概率向量的空间距离

 


 
创作不易,如有帮助,请 点赞 收藏 支持
 


 

[参考文章]
[1]. transformer注意力机制解析
[2]. Seq2Seq的注意力机制
[3]. attention机制图示
[4]. LN与BN的区别
[5]. Seq2Seq的注意力机制
[6]. transformer的decoder结构
[7]. decoder-only和编解码器区别
[8]. Attention is All You Need翻译
[9]. transformer结构详解,推荐

created by shuaixio, 2024.06.23

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1474800.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

ZYNQ-LINUX环境C语言利用Curl库实现HTTP通讯

前言 在Zynq-Linux环境中,需要使用C语言来编写APP时,访问HTTP一般可以使用Curl库来实现,但是在Zynq的SDK中,并没有集成该库,在寻找了很多资料后找到了一种使用很方便的额办法。这篇文章主要记录一下移植Curl的过程。 …

一招解决找不到d3dcompiler43.dll,无法继续执行代码问题

当您的电脑遇到d3dcompiler43.dll缺失问题时,首先需要了解d3dcompiler43.dll文件及其可能导致问题的原因,之后便可以选择合适的解决方案。在此,我们将会为您提供寻找d3dcompiler43.dll文件的多种处理方法。 一、d3dcompiler43.dll文件分析 d…

C++入门7——string类详解

目录 1.什么是string类? 2.string类对象的常见构造 2.1 string(); 2.2 string (const char* s); 2.3 string (const string& str); 2.4 string (const string& str, size_t pos, size_t len npos); 2.5 string (const char* s, size_t n); 2.7 验证…

绝区肆--2024 年AI安全状况

前言 随着人工智能系统变得越来越强大和普及,与之相关的安全问题也越来越多。让我们来看看 2024 年人工智能安全的现状——评估威胁、分析漏洞、审查有前景的防御策略,并推测这一关键领域的未来可能如何。 主要的人工智能安全威胁 人工智能系统和应用程…

Java里的Arrary详解

DK 中提供了一个专门用于操作数组的工具类,即Arrays 类,位于java.util 包中。该类提供了一些列方法来操作数组,如排序、复制、比较、填充等,用户直接调用这些方法即可不需要自己编码实现,降低了开发难度。 java.util.…

Python爬虫系列-让爬虫自己写爬虫(半自动化,代替人工写爬虫)

现在的PC、手机客户端等终端设备大量使用了网页前后端技术,另外主流的网站也会经常会更新,导致以前一个月更新一次爬虫代码,变成了天天需要更新代码,所以自动化爬虫技术在当前就显得特别重要,最近我也是在多次更新某个…

赋值运算符重载和const成员函数和 const函数

文章目录 1.运算符重载(1)(2)运算符重载的语法:(3)运算符重载的注意事项:(4)前置和后置重载区别 2.const成员函数3.取地址及const取地址操作符重载4.总结 1.运算符重载 (1) 我们知道内置类型(整形,字符型,浮点型…)可以进行一系…

TB作品】51单片机 Proteus仿真 51单片机SPI显示OLED字符驱动

// GND 电源地 // VCC 接5V或3.3v电源 // D0 P1^4(SCL) // D1 P1^3(SDA) // RES 接P12 // DC 接P11 // CS 接P10 OLED显示接口与控制实验报告 背景 OLED(有机发光二极管)显示器由于其高对比度、低功耗和…

最新版Python安装教程

一、安装Python 1.下载Python 访问Python官网: https:/www.oython.orgl 点击downloads按钮,在下拉框中选择系统类型(windows/Mac OS./Linux等) 选择下载最新稳定版本的Python 以下内容以演示安装Windows操作系统64位的python 左边是稳定发布版本Stabl…

Linux权限概述

一、权限概述 1.权限的基本概念 2.为什么要设置权限 3.linux用户的身份类别 4.user文件的拥有者 5.group文件所属组内用户 6.other其他用户 7.特殊用户root 二、普通权限管理 1.ls -l查看文件权限 2.文件类型以及权限解析 3.文件或文件夹的权限设置 4.通过数字给文件…

CSRF verification failed. Request aborted.

最近在学习django,遇到这个问题。CSRF verification failed. Request aborted. 解决方案: 1、在Html template中加入csrf_token 2、在view.py中对应的view函数上加上装饰器 再启动运行,报错就解决了。

网页生成二维码、在线演示

https://andi.cn/page/621504.html

Zabbix监控软件

目录 一、什么是Zabbix 二、zabbix监控原理 三、zabbix 安装步骤 一、什么是Zabbix ●zabbix 是一个基于 Web 界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。 ●zabbix 能监视各种网络参数,保证服务器系统的安全运营;并提供灵活的…

通信协议_Modbus协议简介

概念介绍 Modbus协议:一种串行通信协议,是Modicon公司(现在的施耐德电气Schneider Electric)于1979年为使用可编程逻辑控制器(PLC)通信而发表。Modbus已经成为工业领域通信协议的业界标准(De f…

ARM架构和Intel x86架构

文章目录 1. 处理器架构 2. ARM架构 3. Intel x86架构 4. 架构对比 5. 编译过程对比 1. 处理器架构 处理器架构是指计算机处理器的设计和组织方式,它决定了处理器的性能、功耗和功能特性。处理器架构影响着从计算机系统的硬件设计到软件开发的各个方面。在现代…

@[TOC](六、数据可视化—Echars(爬虫及数据可视化))

六、数据可视化—Echars(爬虫及数据可视化) Echarts应用 Echarts Echarts官网,很多图表等都是我们可以 https://echarts.apache.org/zh/index.html 是百度自己做的图表,后来用的人越来越多,捐给了orange组织&#xf…

【ROS2】初级:客户端-创建自定义 msg 和 srv 文件

目标:定义自定义接口文件( .msg 和 .srv )并将它们与 Python 和 C节点一起使用。 教程级别:初学者 时间:20 分钟 目录 背景 先决条件 任务 1. 创建一个新包2. 创建自定义定义3 CMakeLists.txt4 package.xml5. 构建 tut…

Vue3中生成本地pdf并下载

1. 前言 前端中经常会遇到在系统中根据数据导出一个pdf文件出来,一般都是后端来实现的,既然后端可以实现,前端为什么就不行呢,正好有一次也写了这个需求,就写了个小demo 示例图: 2. 实现步骤 首先下载html2pdf.js这个库yarn add html2pdf.js // 或 npm i html2pdf.js在项…

下载,连接mysql数据库驱动(最详细)

前言 本篇博客,我讲讲如何连接数据库?我使用mysql数据库举例。 目录 下载对应的数据库jar 包 百度网盘 存有8.4.0版本压缩包:链接:https://pan.baidu.com/s/13uZtXRmuewHRbXaaCU0Xsw?pwduipy 提取码:uipy 复制这…

数据结构--二叉树和堆

目录 1.基本概念 2.树的遍历方法 3.满二叉树&&完全二叉树 4.逻辑结构&&物理结构 5.推理公式 6.二叉树应用--堆 7.简单实现堆 1.基本概念 (1)这个里面的概念还是比较多的,但是大部分我们只需要了解即可,因为…