AI大模型系列之七:Transformer架构讲解

目录

Transformer是什么?

         输入模块结构:

编码器模块结构:

解码器模块:

输出模块结构:

Transformer核心思想是什么?

Transformer的代码架构

自注意力机制是什么?

多头注意力有什么用?

前馈神经网络

编码器(Encoder)

解码器(Decoder):

基于卷积神经网络(CNN)的编码器-解码器结构


Transformer是什么?

是一种基于注意力机制(attention mechanism)的神经网络架构,最初由Vaswani等人在论文《Attention Is All You Need》中提出。它在自然语言处理(NLP)领域取得了巨大成功,特别是在机器翻译任务中。

传统的循环神经网络(RNNs)和长短时记忆网络(LSTM)在处理长距离依赖关系时存在一些问题,而Transformer引入了自注意力机制来解决这些问题。自注意力机制允许模型在处理序列数据时关注输入序列中的不同位置,而不仅仅是当前位置。这使得Transformer能够并行处理输入序列,加速训练。

Transformer模型设计之初,用于解决机器翻译问题,是完全基于注意力机制构建的编码器-解码器架构,编码器和解码器均由若干个具有相同结构的层叠加而成,每一层的参数不同。编码器主要负责将输入序列转化为一个定长的向量表示,解码器则将这个向量解码为输出序列。Transformer总体架构可分为四个部分:输入部分、编码器、解码器、输出部分。

         输入模块结构:
  1. 源文本嵌入层及其位置编码器
  2. 目标文本嵌入层及其位置编码器
编码器模块结构:
  1. N个编码器层堆叠而成
  2. 每个编码器层由两个子层连接结构组成
  3. 第一个子层连接结构包括一个多头自注意力子层、规范化层和一个残差连接
  4. 第二个子层连接结构包括一个前馈全连接子层、规范化层和一个残差连接

编码器encoder,包含两层,一个self-attention层和一个前馈神经网络,self-attention能帮助当前节点不仅仅只关注当前的词,从而能获取到上下文的语义。

解码器模块:
  1. N个解码器层堆叠而成
  2. 每个解码器层由三个子层连接结构组成
  3. 第一个子层连接结构包括一个多头自注意力子层、规范化层和一个残差连接
  4. 第二个子层连接结构包括一个多头注意力子层、规范化层和一个残差连接
  5. 第三个子层连接结构包括一个前馈全连接子层、规范化层和一个残差连接
  6. 解码器decoder也包含encoder提到的两层网络,但是在这两层中间还有一层attention层,帮助当前节点获取到当前需要关注的重点内容。
输出模块结构:
  1. 线性层
  2. softmax

Transformer核心思想是什么?

自注意力机制(Self-Attention): 模型能够同时考虑输入序列中的所有位置,而不是像传统的固定窗口大小的卷积或循环神经网络一样逐步处理。 传统的神经网络在处理序列数据时,对每个位置的信息处理是固定的,而自注意力机制允许模型在处理每个位置时关注输入序列的其他部分,从而更好地捕捉全局信息。

位置编码(Positional Encoding): 由于Transformer没有显式的顺序信息,为了保留输入序列中元素的位置信息,需要添加位置编码。

多头注意力(Multi-Head Attention): 将自注意力机制应用多次,通过多个注意力头来捕捉不同的关系。

前馈神经网络(Feedforward Neural Network): 每个注意力子层后接一个前馈神经网络,用于学习非线性关系。

Transformer的成功不仅限于NLP领域,还在计算机视觉等领域取得了重要进展。由于其并行计算的优势,Transformer已成为深度学习中的经典模型之一,被广泛用于各种任务。

Transformer的代码架构

因为它涉及到自注意力机制、位置编码、多头注意力等多个关键概念

import torch
import torch.nn as nnclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=512):super(PositionalEncoding, self).__init__()self.encoding = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1).float()div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))self.encoding[:, 0::2] = torch.sin(position * div_term)self.encoding[:, 1::2] = torch.cos(position * div_term)self.encoding = self.encoding.unsqueeze(0)def forward(self, x):return x + self.encoding[:, :x.size(1)].detach()class TransformerModel(nn.Module):def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers):super(TransformerModel, self).__init__()self.embedding = nn.Embedding(vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers)self.fc = nn.Linear(d_model, vocab_size)def forward(self, src, tgt):src = self.embedding(src)src = self.positional_encoding(src)tgt = self.embedding(tgt)tgt = self.positional_encoding(tgt)output = self.transformer(src, tgt)output = self.fc(output)return output# 定义模型
vocab_size = 10000  # 词汇表大小
d_model = 512  # 模型维度
nhead = 8  # 多头注意力的头数
num_encoder_layers = 6  # 编码器层数
num_decoder_layers = 6  # 解码器层数model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers)# 定义输入
src = torch.randint(0, vocab_size, (10, 32))  # 10个序列,每个序列长度为32
tgt = torch.randint(0, vocab_size, (20, 32))  # 20个序列,每个序列长度为32# 前向传播
output = model(src, tgt)
自注意力机制是什么?

Self Attention
先看例子,下列句子是我们想要翻译的输入句子:
The animal didn’t cross the street because it was too tired
这个“it”在这个句子是指什么呢?它指的是street还是这个animal呢?这对于人类来说是一个简单的问题,但是对于算法则不是。
当模型处理这个单词“it”的时候,自注意力机制会允许“it”与“animal”建立联系。
随着模型处理输入序列的每个单词,自注意力会关注整个输入序列的所有单词,帮助模型对本单词更好地进行编码。
如果你熟悉RNN(循环神经网络),回忆一下它是如何维持隐藏层的。RNN会将它已经处理过的前面的所有单词/向量的表示与它正在处理的当前单词/向量结合起来。而自注意力机制会将所有相关单词的理解融入到我们正在处理的单词中
当我们在编码器#5(栈中最上层编码器)中编码“it”这个单词的时,注意力机制的部分会去关注“The Animal”,将它的表示的一部分编入“it”的编码中。


自注意力机制(Self-Attention Mechanism)是Transformer模型的核心组成部分之一,用于处理输入序列中各个位置之间的依赖关系。以下是对自注意力机制的详细解读:

注意力的概念:
注意力机制模拟了人类视觉系统的工作原理,即在处理某个任务时,我们不是对所有信息一视同仁,而是在某个时刻关注一部分信息,而另一时刻可能关注其他信息。在自注意力机制中,模型通过学习到的权重分配给输入序列中的不同位置,以便在生成输出时更加关注相关的部分。

自注意力机制的基本原理:

Query(查询): 通过将输入序列与权重矩阵相乘,得到每个位置的查询向量。查询向量用于衡量每个位置与其他位置的相关性。
Key(键): 通过将输入序列与权重矩阵相乘,得到每个位置的键向量。键向量用于被查询向量衡量,从而计算注意力分布。
Value(值): 通过将输入序列与权重矩阵相乘,得到每个位置的值向量。值向量将根据注意力分布加权求和,形成最终的输出。
注意力分布计算:

计算相似度: 通过查询向量和键向量的点积,计算每个位置的相似度得分。
缩放: 为了避免相似度过大导致的梯度消失或梯度爆炸问题,一般会对相似度进行缩放,常用的缩放因子是输入维度的平方根。
Softmax: 对缩放后的相似度应用Softmax函数,得到注意力权重分布。Softmax确保所有权重的总和为1,使其成为有效的概率分布。
注意力加权求和: 将值向量按照得到的注意力权重进行加权求和,得到最终的自注意力输出。
多头注意力(Multi-Head Attention):
为了增强模型的表达能力,自注意力机制通常会使用多个独立的注意力头。每个头学习不同的查询、键、值权重矩阵,最后将多个头的输出拼接在一起并通过线性映射进行融合。

位置编码(Positional Encoding):
自注意力机制没有直接考虑序列中元素的顺序,为了捕捉序列的位置信息,常常会在输入序列中添加位置编码。位置编码是一个与位置有关的可学习参数,使得模型能够更好地处理序列的顺序信息。

多头注意力有什么用?


多头注意力机制的引入具有以下几个优势:

多头并行计算: 不同注意力头可以并行计算,提高了计算效率。
学习不同表示: 不同头关注输入序列的不同部分,有助于模型学习更丰富、更复杂的特征表示。
提高模型泛化能力: 多头注意力可以使模型在处理不同类型的信息时更加灵活,提高了模型的泛化能力。
通过这种方式,多头注意力机制在Transformer模型中起到了至关重要的作用,使得模型能够更好地捕捉输入序列中的关系,提高了模型的表达能力。

前馈神经网络


前馈神经网络(Feedforward Neural Network)是一种最基本的神经网络结构,也被称为多层感知机(Multilayer Perceptron,MLP)。在深度学习中,前馈神经网络被广泛应用于各种任务,包括图像分类、语音识别、自然语言处理等。下面是对前馈神经网络的详细解读:

1. 基本结构
前馈神经网络由输入层、隐藏层和输出层组成。每一层都包含多个神经元(或称为节点),每个神经元与上一层的所有神经元都有连接,连接上带有权重。每个连接上都有一个权重,表示连接的强度。

输入层(Input Layer): 接受输入特征的层,每个输入特征对应一个输入层神经元。

隐藏层(Hidden Layer): 在输入层和输出层之间的一层或多层神经元,负责学习输入数据中的复杂模式。

输出层(Output Layer): 提供网络的输出,输出的维度通常与任务的要求相匹配,例如,对于二分类任务,可以有一个输出神经元表示两个类别的概率。

2. 激活函数
每个神经元在接收到输入后,会通过激活函数进行非线性变换。常用的激活函数包括:

Sigmoid 函数: 将输入映射到范围 ((0, 1)),适用于二分类问题。

Hyperbolic Tangent(tanh)函数: 将输入映射到范围 ((-1, 1)),具有零中心性,有助于减少梯度消失问题。

Rectified Linear Unit(ReLU)函数: 对于正数输入,输出等于输入;对于负数输入,输出为零。ReLU 是目前最常用的激活函数之一。

Softmax 函数: 用于多分类问题的输出层,将输出转化为概率分布。

3. 前向传播
前馈神经网络的训练过程中,信息从输入层传播到输出层的过程称为前向传播。具体步骤如下:

输入层接收输入特征。

每个神经元接收来自上一层神经元的输入,计算加权和。

加权和经过激活函数进行非线性变换,得到每个神经元的输出。

输出传递到下一层作为输入,重复以上步骤。

最终,网络的输出被用于任务的预测。

编码器+解码器
编码器-解码器结构是深度学习中常用的一种网络架构,特别在图像分割和生成任务中得到广泛应用。以下是对编码器-解码器结构的详细解读:

编码器(Encoder)


特征提取: 编码器的主要作用是从输入数据中提取关键特征。对于图像任务,输入通常是图像,编码器通过一系列卷积层(Convolutional Layers)进行特征提取。这些卷积层可以捕捉图像中的低级别和高级别特征,例如边缘、纹理和对象形状。

降维: 随着网络深度的增加,编码器通常会进行降维操作,通过池化层(Pooling Layers)或步幅较大的卷积层减小特征图的尺寸。这有助于减少计算复杂性和内存需求,并提高网络对输入的抽象表示能力。

语义信息提取: 在编码器的高层级特征表示中,网络通常能够捕捉到更抽象的语义信息,例如图像中的物体类别、结构等。这些特征通常被称为“语义特征”。

解码器(Decoder):


上采样: 解码器负责将编码器提取的特征映射还原为输入数据的尺寸。这通常涉及到上采样操作,其中通过插值或反卷积操作将特征图的尺寸放大。

特征融合: 解码器通常需要与编码器的相应层进行特征融合,以保留从输入到编码器的层次结构中学到的语义信息。这可以通过连接编码器和解码器的相应层来实现,形成所谓的“跳跃连接”(Skip Connections)。

重建输出: 解码器的最终目标是生成与输入数据相匹配的输出。对于图像分割任务,输出通常是一个与输入图像尺寸相同的特征图,其中每个像素或区域对应一个类别的概率或标签。

基于卷积神经网络(CNN)的编码器-解码器结构
import torch
import torch.nn as nnclass EncoderDecoder(nn.Module):def __init__(self):super(EncoderDecoder, self).__init__()# 编码器部分self.encoder = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2)# 添加更多卷积层和池化层...)# 解码器部分self.decoder = nn.Sequential(# 添加上采样层和特征融合...nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),nn.Conv2d(32, 1, kernel_size=3, padding=1),nn.Sigmoid()  # 用于二分类任务时添加Sigmoid激活函数)def forward(self, x):# 编码器前向传播x = self.encoder(x)# 解码器前向传播x = self.decoder(x)return x# 创建模型实例
model = EncoderDecoder()# 打印模型结构
print(model)

编码器和解码器的结构可能会更加复杂,具体的设计取决于任务的要求和数据集的特点。上述代码中使用的是PyTorch的简单卷积层、池化层和上采样层,实际场景中可能需要更深的网络结构和更复杂的组件。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1421804.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

ohmyzsh的安装过程中失败拒绝连接问题的解决

1.打开官网Oh My Zsh - a delightful & open source framework for Zsh 在官网能看到下面的界面 有这两种自动安装的方式 个人本次选择的是: wget https://raw.github.com/ohmyzsh/ohmyzsh/master/tools/install.sh -O - 1.打开终端输入安装的指令 sh -c "$(wget…

etcd集群恢复、单节点恢复操作手册

一、集群备份 备份方式:Jenkins触发每小时的定时任务,通过调取ansible的playbook进行etcd集群的数据备份和上传,默认只备份集群中的非leader成员,避免leader成员压力过大。将备份数据上传到对应的公有云对象存储,分别…

软件测试总体报告(实际项目原件Word参考)

软件全套精华资料包清单部分文件列表: 工作安排任务书,可行性分析报告,立项申请审批表,产品需求规格说明书,需求调研计划,用户需求调查单,用户需求说明书,概要设计说明书&#xff0c…

bat xcopy 解析

echo off set source_folder"C:\path\to\source" set destination_folder"C:\path\to\destination" set exclude_file"C:\path\to\excluded_folders.txt"REM 创建目标文件夹(如果不存在) mkdir %destination_folder% 2>…

01-02-1

1、day10作业 使用的代码 #include<stdio.h> void change(int* i) {*i(*i) / 2; } int main() {int i 0;scanf("%d", &i);change(&i);printf("%d", i);return 0; } ​ ​ 2、day11作业 使用的代码 #include<stdio.h> #include<…

如何在windows server下安装mysql5.7数据库,并使用Navicat Premium 15可视化工具新建数据库并读取数据库信息。

如何在windows server下安装mysql5.7数据库&#xff1f; MySQL :: Download MySQL Community Server (Archived Versions)https://downloads.mysql.com/archives/community/点击↑&#xff0c;然后选择对应版本和平台↓下载 将下载后的安装包放入固定目录&#xff08;这里以D:…

Linux0.11 中全局描述符表(GDT)

在Linux内核中&#xff0c;全局描述符表&#xff08;Global Descriptor Table&#xff0c;简称GDT&#xff09;是一个关键的数据结构&#xff0c;主要用于管理处理器的内存段和相关的权限与属性。它属于x86架构中的保护模式特性&#xff0c;允许操作系统对内存访问进行更精细的…

代码大师的工具箱:现代软件开发利器

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

RedisTemplate操作Redis详解之连接Redis及自定义序列化

连接到Redis 使用Redis和Spring时的首要任务之一是通过IoC容器连接到Redis。为此&#xff0c;需要java连接器&#xff08;或绑定&#xff09;。无论选择哪种库&#xff0c;你都只需要使用一组Spring Data Redis API&#xff08;在所有连接器中行为一致&#xff09;&#xff1a;…

【原理代码详解】DeepSORT算法:多目标跟踪的深度学习解决方案

一、引言 在视频监控和智能交通系统中&#xff0c;多目标跟踪是一项关键技术&#xff0c;它涉及检测视频中的多个目标&#xff0c;并在视频帧之间维持每个目标的身份。DeepSORT算法作为SORT算法的扩展&#xff0c;通过结合深度学习和传统的跟踪技术&#xff0c;提高了目标跟踪…

李飞飞团队关于2024年人工智能发展报告总结 (Artificial Intelligence Index Report)

目录 1 10大核心信息2 AI研究和发展2.1 核心要点2.2 核心对比信息2.3 模型是否会用尽数据2.4 基础模型发展2.5 训练模型成本 3 技术性能3.1 核心要点3.2 重要模型发布情况3.3 AI表现情况3.4 多学科、高难度评估集 (MMMU & GPQA & ARC)3.5 Agents3.6 RLHF & RLAIF3.…

R语言数据分析案例-股票题目分析

Value at Risk&#xff08;VaR&#xff09;是一种统计技术&#xff0c;用于量化投资组合在正常市场条件下可能遭受的最大潜在损失。它是风险管理和金融领域中一个非常重要的概念。VaR通常以货币单位表示&#xff0c;用于估计在给定的置信水平和特定时间范围内&#xff0c;投资组…

基于网络的无人海洋船舶控制

书籍&#xff1a;Network-Based Control of Unmanned Marine Vehicles 作者&#xff1a;Yu-Long Wang&#xff0c;Qing-Long Han&#xff0c;Chen Peng&#xff0c;Lang Ma 出版&#xff1a;Springer 书籍下载-《基于网络的无人海洋船舶控制》控制系统中的通信网络可能引起延…

28.6k Star!Dify:完善生态、支持Ollama与本地知识库、企业级拖放式UI构建AI Agent、API集成进业务!

原文链接&#xff08;更好排版、视频播放、社群交流&#xff09; 28.6k Star&#xff01;Dify&#xff1a;完善生态、支持Ollama与本地知识库、企业级拖放式UI构建AI Agent、API集成进业务&#xff01; 原创 Aitrainee [ AI进修生 ](javascript:void(0)&#x1f609; AI进修…

触摸OpenNJet,云原生世界触手可及

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” 文章目录 导言OpenNJet云原生引擎介绍云原生平台的介绍优化与创新 为什么选择OpenNJet云原生引擎如何在windo…

SAP 控制已转采购订单的PR不允许删除简介

SAP系统中采购申请当被转成采购订单后&#xff0c;在采购申请中会关联到对应已生生成的采购订单&#xff0c;如下图中可以看到采购申请对应的采购订单 当日常操作中用户在创建完采购申请后&#xff0c;当PR转成PO后仍然可以将采购申请的行项目进行删除&#xff0c;显然这个操作…

maven .lastUpdated文件作用

现象 有时候我在用maven管理项目时会发现有些依赖报错&#xff0c;这时你可以看一下本地仓库中是否有.lastUpdated文件&#xff0c;也许与它有关。 原因 有这个文件就表示依赖下载过程中发生了错误导致依赖没成功下载&#xff0c;可能是网络原因&#xff0c;也有可能是远程…

做国外问卷调查,一天能挣多少钱?

大家好​&#xff0c;我是汇舟问卷&#xff0c;专注于国外问卷调查项目已经五年的时间了&#xff0c;目前做的一直比较稳定。 这个项目说白了就是通过搭建国外的环境&#xff0c;登录问卷平台&#xff0c;通过参与国外企业发布的问卷调查来获取​美金奖励。 那么参与的问卷的…

2.数据类型与变量(java篇)

目录 数据类型与变量 数据类型 变量 整型变量 长整型变量 短整型变量 字节型变量 浮点型变量 双精度浮点型 单精度浮点型 字符型变量 布尔型变量&#xff08;boolean&#xff09; 类型转换 自动类型转换(隐式) 强制类型转换(显式) 类型提升 字符串类型 数据类…

中医揿针的注意事项

点击文末领取揿针的视频教程跟直播讲解 关于揿针的注意事项&#xff0c;我们可以从以下几个方面进行探讨&#xff1a; 01操作前准备 1. 确保针具的清洁和无菌状态&#xff0c;以避免感染。 2. 了解患者的身体状况&#xff0c;如是否有特殊疾病或过敏史&#xff0c;以便选择…