Tansformer代码实现

目录

1.Tansformer架构图

2.代码实现

2.1创建类:实现基于位置的前馈网络 

 2.2创建 残差&LN层标准归一化的类

2.3编码器block

2.4创建编码器 

2.5创建解码器 

2.6transformer解码器部分 

3.知识点个人理解


 

1.Tansformer架构图

2.代码实现

2.1创建类:实现基于位置的前馈网络 

#创建类:实现基于位置的前馈网络
#前馈网络(position wise feed_forward_net)
class PositionWiseFFN(nn.Module):def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs, **kwargs):super().__init__(**kwargs)self.dense1 = nn.Linear(ffn_num_inputs, ffn_num_hiddens)  #输入层self.relu = nn.ReLU()self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs) #输出层 没有激活函数def forward(self, X):return self.dense2(self.relu(self.dense1(X)))
#测试上面创建的类
#创建前馈网络对象
ffn = PositionWiseFFN(ffn_num_inputs=6, ffn_num_hiddens=4, ffn_num_outputs=8)
ffn.eval()
#假设数据
X = torch.ones((2, 3, 6))  #  X的shape=(batch_size, num_steps, num_hiddens)
ffn(X).shape

 torch.Size([2, 3, 8])

 2.2创建 残差&LN层标准归一化的类

class AddNorm(nn.Module):#残差连接后进行层规范化def __init__(self, normalized_shape, dropout, **kwargs):super().__init__(**kwargs)self.dropout = nn.Dropout(dropout)self.ln = nn.LayerNorm(normalized_shape)def forward(self, X, Y):return self.ln(self.dropout(Y) + X)  #残差结构       

 

#使用残差结构,输入的数据的维度必须相同!!1
add_norm = AddNorm([6, 8], dropout=0.2)
add_norm.eval()
X = torch.ones((2, 6, 8), dtype=torch.float32)
Y = torch.ones((2, 6, 8), dtype=torch.float32)
add_norm(X, Y)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0.]]],grad_fn=<NativeLayerNormBackward0>)

2.3编码器block

class EncoderBlock(nn.Module):def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs, num_heads, dropout, use_bias=False, **kwargs):super().__init__(**kwargs)self.attention = dltools.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)self.addnorm1 = dltools.AddNorm(norm_shape, dropout)self.ffn = dltools.PositionWiseFFN(ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs)self.addnorm2 = dltools.AddNorm(norm_shape, dropout)def forward(self, X, valid_lens):Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))return self.addnorm2(Y, self.ffn(Y))
#假设X数据
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 24,  8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape
torch.Size([2, 100, 24])

2.4创建编码器 

# 编码器
class TransformerEncoder(dltools.Encoder):def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, num_layers, dropout, use_bias=False, **kwargs):super().__init__(**kwargs)self.num_hiddens = num_hiddensself.embedding = nn.Embedding(vocab_size, num_hiddens)self.pos_encoding = dltools.PositionalEncoding(num_hiddens, dropout)self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module('block' + str(i), EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,num_heads, dropout, use_bias))def forward(self, X, valid_lens, *args):# 对embedding之后的数据进行缩放, 有助于收敛X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))self.attention_weights = [None] * len(self.blks)for i, blk in enumerate(self.blks):X = blk(X, valid_lens)self.attention_weights[i] = blk.attention.attention.attention_weightsreturn X
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 24, 8, 2, 0.3)
encoder.eval()
X = torch.ones((2, 100), dtype=torch.long)
encoder(X, valid_lens).shape
torch.Size([2, 100, 24])

2.5创建解码器 

 

 

#decoder block
class DecoderBlock(nn.Module):def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, dropout, i, **kwargs):super().__init__(**kwargs)self.i = i#第一层多头注意力机制self.attention1 = dltools.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)#残差与LN层归一化self.addnorm1 = dltools.AddNorm(norm_shape, dropout)#第二层多头注意力机制self.attention2 = dltools.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)#残差与LN层归一化self.addnorm2 = dltools.AddNorm(norm_shape, dropout)#ffn层self.ffn = dltools.PositionWiseFFN(ffn_num_input, ffn_num_hiddens, ffn_num_outputs)self.addnorm3 = dltools.AddNorm(norm_shape, dropout)def forward(self, X, state):#state包含三个东西:enc_outputs, enc_valid_lens,dec_outputs enc_outputs, enc_valid_lens = state[0], state[1]#state[2][self.i]表示当前的dec_outputs#若当前的dec_outputs为空, 表示处于训练模式:即X是需要传入的强制学习的目标标记(文本序列)if state[2][self.i] is None:key_values = Xelse: #若不为空,就是处于预测模式  (训练好了,state[2][self.i]就会存当前block的输出储数据)#预测需要把前面时刻预测得到的信息和当前的block的输出得到的信息拼到一起#此时的X表示前面时刻预测得到的信息key_values = torch.cat((state[2][self.i], X), axis=1)state[2][self.i] = key_values  #存储当前block的输出得到的信息,赋值给变量#在训练时, 需要对真实值进行遮蔽if self.training:  #若处于训练#X是三维数据, shape= (batch_size, num_steps, num_hiddens)batch_size, num_steps, _ = X.shape#dec_valid_lens的shape=(batch_size, num_steps),每一行数据都是[1, 2, ....., num_steps]dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)else:dec_valid_lens = None  #为空,表示不训练,处于预测中,预测不需要遮蔽序列元素#多头自注意力层X2 = self.attention1(X, key_values, key_values, dec_valid_lens)#残差+LN层归一化Y = self.addnorm1(X, X2)#多头自注意力层Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)Z = self.addnorm2(Y, Y2)return self.addnorm3(Z, self.ffn(Z)), state      
valid_lens
tensor([3, 2])
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 24, 24, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
result1, result2 = decoder_blk(X, state)
result1.shape,    result2[2][0].shape

都是

torch.Size([2, 100, 24])

2.6transformer解码器部分 

# transformer解码器部分
class TransformerDecoder(dltools.AttentionDecoder):def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, num_layers, dropout, **kwargs):super().__init__(**kwargs)self.num_hiddens = num_hiddensself.num_layers = num_layersself.embedding = nn.Embedding(vocab_size, num_hiddens)self.pos_embedding = dltools.PositionalEncoding(num_hiddens, dropout)self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module('block' + str(i), DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, dropout, i))self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):return [enc_outputs, enc_valid_lens, [None]*self.num_layers]def forward(self, X, state):X = self.pos_embedding(self.embedding(X) * math.sqrt(self.num_hiddens))self._attention_weights = [[None] * len(self.blks) for _ in range(2)]for i, blk in enumerate(self.blks):X, state = blk(X, state)self._attention_weights[0][i] = blk.attention1.attention.attention_weightsself._attention_weights[1][i] = blk.attention2.attention.attention_weightsreturn self.dense(X), state@propertydef attention_weights(self):return self._attention_weights
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()
ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads = 32, 64,32, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

# 开始预测
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

 

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est calme .', []), bleu 1.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146907.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

连续数组问题

目录 一题目&#xff1a; 二思路&#xff1a; 三代码&#xff1a; 一题目&#xff1a; leetcode链接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 二思路&#xff1a; 思路&#xff1a;前缀和&#xff08;第二种&#xff09;化0为-1hash&#xff1a; 这样可以把…

【大模型实战篇】一种关于大模型高质量数据的处理方法-无标注数据类别快速识别及重复数据检测(加权向量-卷积神经网络-聚类算法结合)

1. 背景介绍 大模型的能力很大程度上依赖于高质量的数据&#xff0c;在之前的一篇文章《高质量数据过滤及一种BoostedBaggingFilter处理方法的介绍》中&#xff0c;我们介绍了大模型的数据处理链路&#xff0c;本文继续关注在高质量数据的模块。 本文所要介绍的处理方法&…

vscode 配置django

创建运行环境 使用pip安装Django&#xff1a;pip install django。 创建一个新的Django项目&#xff1a;django-admin startproject myproject。 打开VSCode&#xff0c;并在项目文件夹中打开终端。 在VSCode中安装Python扩展&#xff08;如果尚未安装&#xff09;。 在项…

滑动窗口经典题目

目录 滑动窗口 什么是滑动窗口&#xff1f; 什么时候用滑动窗口&#xff1f; 怎么用滑动窗口&#xff1f; 209. 长度最小的子数组&#xff08;滑动窗口的引入&#xff09; 3. 无重复字符的最长子串 1004. 最大连续1的个数 III 1658. 将 x 减到 0 的最小操作数 904. 水…

Fyne ( go跨平台GUI )中文文档-容器和布局 (四)

本文档注意参考官网(developer.fyne.io/) 编写, 只保留基本用法 go代码展示为Go 1.16 及更高版本, ide为goland2021.2 这是一个系列文章&#xff1a; Fyne ( go跨平台GUI )中文文档-入门(一)-CSDN博客 Fyne ( go跨平台GUI )中文文档-Fyne总览(二)-CSDN博客 Fyne ( go跨平台GUI…

【重学 MySQL】三十七、聚合函数

【重学 MySQL】三十七、聚合函数 基本概念5大常用的聚合函数COUNT()SUM()AVG()MAX()MIN() 使用场景注意事项示例查询 聚合函数&#xff08;Aggregate Functions&#xff09;在数据库查询中扮演着至关重要的角色&#xff0c;特别是在处理大量数据时。它们能够对一组值执行计算&a…

37. Vector3与模型位置、缩放属性

本文章给通过组对象Group (opens new window)给大家讲解一下threejs层级模型或树结构的概念。 Group层级模型(树结构)案例 下面代码创建了两个网格模型mesh1、mesh2&#xff0c;通过THREE.Group类创建一个组对象group,然后通过add方法把网格模型mesh1、mesh2作为设置为组对象g…

Vuex的使用看这一篇就够了

Vuex概述 Vuex 是一个专为 Vue.js 应用程序开发的状态管理库。它采用集中式存储管理应用的所有组件的状态&#xff0c;并以一种可预测的方式来保证状态以一种可预测的方式发生变化。 state状态 把公用的数据放到store里的state就行了&#xff0c;上面是vue2的代码&#xff0c;下…

[大语言模型] LINFUSION:1个GPU,1分钟,16K图像

1. 文章 2409.02097 (arxiv.org)https://arxiv.org/pdf/2409.02097 LINFUSION: 1 GPU, 1 MINUTE, 16K IMAGE 摘要 本文介绍了一种新型的扩散模型LINFUSION&#xff0c;它能够在保持高分辨率图像生成性能的同时显著降低时间和内存复杂度。该模型采用了基于Transformer的UNet进…

【前端】ES6:Class语法和Class继承

文章目录 1 Class语法1.1 类的写法1.2 getter与setter1.3 静态属性和静态方法 2 Class继承 1 Class语法 1.1 类的写法 class Person {constructor(name,age){this.name name;this.age age;}say(){console.log(this.name,this.age)} } let obj new Person("kerwin&quo…

python--基础语法(2)

1.顺序语句 默认情况下&#xff0c;Python的代码执行顺序是按照从上到下的顺序&#xff0c;依次执行的。 2.条件语句 条件语句能够表达“如果 ...否则 ...”这样的语义这构成了计算机中基础的逻辑判定条件语&#xff0c; 也叫做 分支语句。表示了接下来的逻辑可能有几种走向…

SysML图例-10cm最小航天器AC-10

DDD领域驱动设计批评文集>> 《软件方法》强化自测题集>> 《软件方法》各章合集>> SysML图中词汇 AC10 AeroCube-10&#xff0c;大小仅为10 10 15 cm的卫星&#xff0c;更多信息参见下文&#xff1a; AeroCube-10成为迄今为止完成在轨接近操作的最小航天…

yolov8模型在手部关键点检测识别中的应用【代码+数据集+python环境+GUI系统】

yolov8模型在手部关键点检测识别中的应用【代码数据集python环境GUI系统】 背景意义 在手势识别、虚拟现实&#xff08;VR&#xff09;、增强现实&#xff08;AR&#xff09;等领域&#xff0c;手部关键点检测为用户提供了更加自然、直观的交互方式。通过检测手部关键点&#…

移动登录页:让用户开启一段美好的旅程吧。

Hi,大家好&#xff0c;我是大千UI工场&#xff0c;移动登录页千千万&#xff0c;这里最好看&#xff0c;本期分享一批移动端的登录页面&#xff0c;供大家欣赏。 本次分享的是毛玻璃/3D风格的登录页。

Linux文件IO(七)-复制文件描述符

在 Linux 系统中&#xff0c;open 返回得到的文件描述符 fd 可以进行复制&#xff0c;复制成功之后可以得到一个新的文件描述符&#xff0c;使用新的文件描述符和旧的文件描述符都可以对文件进行 IO 操作&#xff0c;复制得到的文件描述符和旧的文件描述符拥有相同的权限&#…

【文化课学习笔记】【化学】选必三:合成高分子生物大分子

【化学】选必三&#xff1a;合成高分子&生物大分子 如果你是从 B 站一化儿笔记区来的&#xff0c;请先阅读我在第一篇有机化学笔记中的「读前须知」(点开头的黑色小三角展开)&#xff1a;链接 加聚反应 基本概念 聚合反应 由小分子化合物合成高分子化合物的反应叫聚合反应。…

学习 git 命令行的简单操作, 能够将代码上传到 Gitee 上

首先登录自己的gitee并创建好仓库 将仓库与Linux终端做链接 比如说我这里已经创建好了一个我的Linux学习仓库 点开克隆/下载&#xff1a; 在你的终端中粘贴上图中1中的指令 此时他会让你输入你的用户名和密码&#xff0c;用户名就是上图中3中Username for ....中后面你的一个…

秒变 Vim 高手:必学的编辑技巧与隐藏功能大揭秘

文章目录 前言一、vi与vim二、Vim的三种模式1. 普通模式2. 插入模式3. 命令模式 三、Vim中的查找与替换1. 查找2. 替换 四、给Vim设置行号1. 临时显示行号2. 永久显示行号 总结 前言 在Linux系统中&#xff0c;文本编辑器是开发者和系统管理员日常工作中的重要工具之一。其中&…

手机号归属地查询-运营商归属地查询-手机号归属地信息-运营商手机号归属地查询接口-手机号归属地

手机号归属地查询接口是一种网络服务接口&#xff0c;它允许开发者通过编程方式查询手机号码的注册地信息。这种接口通常由第三方服务提供商提供&#xff0c;并可通过HTTP请求进行调用。以下是一些关于手机号归属地查询接口的相关信息&#xff1a; 1. 接口功能 归属地查询&am…

HTB-GreenHorn 靶机笔记

GreenHorn 靶机笔记 概述 GreenHorn 是 HTB 上的一个 linux easy 难度的靶机&#xff0c;主要是通过信息搜集和代码审计找到对我们有用的信息。其中还包含了对pdf文件的修复技术 靶机地址&#xff1a;https://app.hackthebox.com/machines/GreenHorn 一丶 nmap 扫描 1&…