机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property    #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层  (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播   (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], []  #创建空列表,用于存储数据for x in X:  #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围,  range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score

 4.预测

import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146389.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

动手学深度学习(五)循环神经网络RNN

一、序列模型 1、统计工具 ①联合概率分布 假设有一个序列 x[x1,x2,…,xT],我们可以把序列的联合概率分解为多个条件概率的乘积。 ②建模 f(x1​,…,xt−1​) 是一个函数,用于提取前 t−1个序列元素的信息。这意味着我们不需要存储每一个之前的序列元…

关于群里脱敏系统的讨论2024-09-20

群里大家讨论脱敏系统,傅同学:秦老师,银行数据脱敏怎么做的,怎么存储的? 采购了脱敏系统,一般是硬件(厂商直接卖的一体机)。这个系统很复杂,大概卖50-100万一台。 最核…

Springboot常见问题(bean找不到)

如图错误显示userMapper bean没有找到。 解决方案: mapper包位置有问题:因为SpringBoot默认的包扫描机制会扫描启动类所在的包同级文件和子包下的文件。注解问题: 比如没有加mapper注解 然而无论是UserMapper所在的包位置还是Mapper注解都是…

HelpLook VS GitBook,在线文档管理工具对比

在线文档管理工具在当今时代非常重要。随着数字化时代的到来,人们越来越依赖于电子文档来存储、共享和管理信息。无论是与团队合作还是与客户分享,人们都可以轻松地共享文档链接或通过设置权限来控制访问。在线文档管理工具的出现大大提高了工作效率和协…

echarts 散点图tooltip显示一个点对应多个y值

tooltip&#xff1a;显示 tooltip: {trigger: "axis",extraCssText: max-width:50px; white-space:pre-wrap,formatter: function (params) {let arr []params.forEach(v > {arr.push(v.data[1])});return params[0].data[0]":<br>["arr.toStr…

leetcode刷题3

文章目录 前言回文数1️⃣ 转成字符串2️⃣ 求出倒序数再比对 正则表达式匹配[hard]1️⃣ 动态规划 盛最多水的容器1️⃣ 遍历分类2️⃣ 双指针贪心 最长公共前缀1️⃣ 遍历&#xff08;zip解包&#xff09; 三数之和1️⃣ 双指针递归 最接近的三数之和1️⃣ 迭代一次双指针 电…

PCL addLine可视化K近邻

目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 2.1关键函数 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接&#xff1a; PCL点云算法与项目实战案例汇总&#xff08;长期更新&#xff09; 一、概述 本文将介绍如何使用PCL库中…

Unreal Engine 5 C++: 编辑器工具编写入门(中文解释)

目录 准备工作 1.创建插件 2.修改插件设置 快速资产操作&#xff08;quick asset action) 自定义编辑器功能 0.创建编辑器button&#xff0c;测试debug message功能 大致流程 详细步骤 1.ctrlF5 launch editor 2.创建新的cpp class&#xff0c;derived from AssetAction…

基于PHP的CRM管理系统源码/客户关系管理CRM系统源码/php源码/附安装教程

源码简介&#xff1a; 这是一款基于PHP开发的CRM管理系统源码&#xff0c;全称客户关系管理CRM系统源码&#xff0c;它是由php源码开发的&#xff0c;还附带了一整套详细的安装教程哦&#xff01; 功能亮点&#xff1a; 1、公海管理神器&#xff1a;不仅能搞定公海类型&…

【自然语言处理】补充:布尔模型

【自然语言处理】补充:布尔模型 布尔检索是指针对查询的检索,布尔查询是指利用AND,OR或者NOT操作符将词项连接起来的查询,例如:信息AND检索、信息OR检索、信息AND检索AND NOT教材 Google的高级搜索/布尔查询 Google的AND—百度 “ 手机 报价 ”Google的NOT—百度 “ 手机…

关于MATLAB计算3维图的向量夹角总是不正确的问题记录

文章目录 问题描述解决方法完整代码 问题描述 因为最近在做无人机的一个项目&#xff0c;所以需要画出无人机的轨迹&#xff0c;然后再提取特征值&#xff0c;我这里在计算夹角的时候发现为什么在视觉上明明看的是钝角但是实际计算出来却是锐角的角度。 如下图所示&#xff0c…

Spring面试题合集

Spring 1.谈谈你对Spring的理解 首先Spring是一个轻量级的开源框架&#xff0c;为Java程序的开发提供了基础架构支持&#xff0c;简化了应用开发&#xff0c;让开发者专注于开发逻辑&#xff1b; 同时Spring是一个容器&#xff0c;它通过管理Bean的生命周期和依赖注入&#…

无处不在的人工智能

文章目录 引言科幻电影中的AI《她》&#xff1a;人工智能的爱情《我&#xff0c;机器人》&#xff1a;AI的觉醒 人工智能的发展现状专用人工智能的突破通用人工智能的起步 结语 引言 在21世纪的今天&#xff0c;人工智能&#xff08;AI&#xff09;已经成为推动社会发展的关键…

英集芯IP5902:集成电压可调异步升压转换充电管理功能的8位MCU芯片

英集芯IP5902是一款集成了9V异步升压转换、锂电池充电管理及负端NMOS管的8-bit MCU芯片&#xff0c;外壳采用了SOP16封装形式&#xff0c;高集成度和丰富的功能使其在应用时只需很少的外围器件&#xff0c;就能有效减小整体方案的尺寸&#xff0c;降低BOM成本&#xff0c;为小型…

dockercompose指定配置文件

dockercompose指定配置文件 文件名字必须是以下的集中形式&#xff1a; docker-compose.yaml docker-compose.yml compose.yaml compose.yml 其他名字就失败的。 一般白眉大叔都是用 compose.yaml 这个格式&#xff0c; 用习惯了。 但是我们必须知道它有几种格式都是可以…

聚焦于 Web 性能指标 TTI

在优化网站性能的过程中&#xff0c;我们经常遇到一个“为指标而优化”的困境。指标并不能真正反映用户体验&#xff0c;而应该最真实地反映用户行为。 在本节中&#xff0c;我们将研究 TTI&#xff08;Time to Interactive&#xff09;。在深入探讨这个话题之前&#xff0c;我…

信奥初赛解析:1-3-计算机软件系统

知识要点 软件系统是计算机的灵魂。没有安装软件的计算机称为“裸机”&#xff0c;无法完成任何工作硬件为软件提供运行平台。软件和硬件相互关联,两者之间可以相互转化&#xff0c;互为补充 计算机软件系统按其功能可分为系统软件和应用软件两大类 一、系统软件 系统软件是指…

HTTP中的event-stream,eventsource,SSE,chatgpt,stream request,golang

我们都知道chatgpt是生成式的&#xff0c;因此它返回给客户端的消息也是一段一段的&#xff0c;所以普通的HTTP协议无法满足&#xff0c;当然websocket是能满足的&#xff0c;但是这个是双向的通信&#xff0c;其实 SSE&#xff08;Server-Sent Events&#xff09; 正好满足这个…

【操作教程】视频监控系统EasyCVR视频汇聚管理平台如何添加用户和角色?

视频监控平台/视频监控系统EasyCVR视频汇聚管理平台以其强大的拓展性、灵活的部署方式、高性能的视频能力和智能化的分析能力&#xff0c;为各行各业的视频监控需求提供了优秀的解决方案。通过简单的配置和操作&#xff0c;用户可以轻松地进行远程视频监控、存储和查看&#xf…

永磁同步电机谐波抑制算法(8)——基于神经网络的傻瓜式(无需知道谐波频率)谐波抑制

1.简介 前面的内容已经介绍了很多谐波抑制的方法&#xff1a;多同步、PIR、陷波器等等。也介绍了比较多的谐波来源&#xff1a;死区&#xff08;5、7、11、13等次相电流谐波&#xff09;、绕组不对称&#xff08;基波不等幅值、3次相电流谐波&#xff09;等等。 上述的方法都…