基于MindSpore实现Transformer机器翻译(下)

因本文内容较长,故分为上下两部分。上部分可点击以下链接查看
基于MindSpore实现Transformer机器翻译(上)

编码器(Encoder)

Transformer的Encoder负责处理输入的源序列,并将输入信息整合为一系列的上下文向量(context vector)输出。

每个encoder层中存在两个子层:多头自注意力(multi-head self-attention)和基于位置的前馈神经网络(position-wise feed-forward network)。

子层之间使用了残差连接(residual connection),并使用了层规范化(layer normalization)。二者统称为“Add & Norm”

encoder

基于位置的前馈神经网络 (Position-Wise Feed-Forward Network)

基于位置的前馈神经网络被用来对输入中的每个位置进行非线性变换。它由两个线性层组成,层与层之间需要经过ReLU激活函数。

F F N ( x ) = R e L U ( x W 1 + b 1 ) W 2 + b 2 \mathrm{FFN}(x) = \mathrm{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2

相比固定的ReLU函数,基于位置的前馈神经网络可以处理更加复杂的关系,并且由于前馈网络是基于位置的,可以捕获到不同位置的信息,并为每个位置提供不同的转换。

Add & Norm

Add & Norm层本质上是残差连接后紧接了一个LayerNorm层。

Add&Norm ( x ) = LayerNorm ( x + Sublayer ( x ) ) \text{Add\&Norm}(x) = \text{LayerNorm}(x + \text{Sublayer}(x)) Add&Norm(x)=LayerNorm(x+Sublayer(x))

  • Add:残差连接,帮助缓解网络退化问题,注意需要满足 x x x SubLayer ( x ) 的形状一致 \text{SubLayer}(x)的形状一致 SubLayer(x)的形状一致
  • Norm:Layer Norm,层归一化,帮助模型更快地进行收敛;

解码器 (Decoder)

decoder

解码器将编码器输出的上下文序列转换为目标序列的预测结果 Y ^ \hat{Y} Y^,该输出将在模型训练中与真实目标输出 Y Y Y进行比较,计算损失。

不同于编码器,每个Decoder层中包含两层多头注意力机制,并在最后多出一个线性层,输出对目标序列的预测结果。

  • 第一层:计算目标序列的注意力分数的掩码多头自注意力
  • 第二层:用于计算上下文序列与目标序列对应关系,其中Decoder掩码多头注意力的输出作为query,Encoder的输出(上下文序列)作为key和value;

带掩码的多头注意力

在处理目标序列的输入时,t时刻的模型只能“观察”直到t-1时刻的所有词元,后续的词语不应该一并输入Decoder中。

为了保证在t时刻,只有t-1个词元作为输入参与多头注意力分数的计算,我们需要在第一个多头注意力中额外增加一个时间掩码,使目标序列中的词随时间发展逐个被暴露出来。

该注意力掩码可通过三角矩阵实现,对角线以上的词元表示为不参与注意力计算的词元,标记为1。

0 1 1 1 1 0 0 1 1 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 \begin{matrix} 0 & 1 & 1 & 1 & 1\\ 0 & 0 & 1 & 1 & 1\\ 0 & 0 & 0 & 1 & 1\\ 0 & 0 & 0 & 0 & 1\\ 0 & 0 & 0 & 0 & 0\\ \end{matrix} 0000010000110001110011110

该掩码一般被称作subsequent mask。

最后,将subsequent mask和padding mask合并为一个整体的掩码,确保模型既不会注意到t时刻以后的词元,也不会关注为 <pad> 的词元。

dec-self-attn-mask

通过Transformer实现文本机器翻译

全流程

  • 数据预处理: 将图像、文本等数据处理为可以计算的Tensor
  • 模型构建: 使用框架API, 搭建模型
  • 模型训练: 定义模型训练逻辑, 遍历训练集进行训练
  • 模型评估: 使用训练好的模型, 在测试集评估效果
  • 模型推理: 将训练好的模型部署, 输入新数据获得预测结果

数据准备

我们本次使用的数据集为Multi30K数据集,它是一个大规模的图像-文本数据集,包含30K+图片,每张图片对应两类不同的文本描述:

  • 英语描述,及对应的德语翻译;
  • 五个独立的、非翻译而来的英语和德语描述,描述中包含的细节并不相同;

因其收集的不同语言对于图片的描述相互独立,所以训练出的模型可以更好地适用于有噪声的多模态内容。

Multi30K

图片来源:Elliott, D., Frank, S., Sima’an, K., & Specia, L. (2016). Multi30K: Multilingual English-German Image Descriptions. CoRR, 1605.00459.

在本次文本翻译任务中,德语是源语言(source languag),英语是目标语言(target language)。

数据预处理

在使用数据进行模型训练等操作时,我们需要对数据进行预处理,流程如下:

  1. 加载数据集;
  2. 构建词典;
  3. 创建数据迭代器;
数据加载器

加载数据集,并进行分词,即将句子拆解为单独的词元(token,可以为字符或者单词)。一般在机器翻译类任务中,我们习惯进行单词级词元化,即每个词元要么为一个单词,要么为一个标点符号。同一个单词,不论首字母是否大写,都应该对应同一个词元,故在分词前,我们需统一将单词转换为小写。

“Hello world!” --> [“hello”, “world”, “!”]

接下来,我们创建数据加载器Multi30K。后期调用该类进行遍历时,每次返回当前源语言(德语)与目标语言(英语)文本描述的词元列表。

词典

将每个词元映射到从0开始的数字索引中(为节约存储空间,可过滤掉词频低的词元),词元和数字索引所构成的集合叫做词典(vocabulary)。

以上述“Hello world!”为例,该序列组成的词典为:

{“<unk>”: 0, “<pad>”: 1, “<bos>”: 2, “<eos>”: 3, “hello”: 4, “world”: 5, “!”: 6}

在构建词典中,我们使用了4个特殊词元。

  • <unk>:未知词元(unknown),将出现次数少于一定频率的单词统一判定为未知词元;
  • <bos>:起始词元(begin of sentence),用来标注一个句子的开始;
  • <eos>:结束词元(end of sentence),用来标注一个句子的结束;
  • <pad>:填充词元(padding),当句子长度不够时将句子填充至统一长度;

通过Vocab创建词典后,我们可以实现词元与数字索引之间的互相转换。我们可以通过调用enocde函数,返回输入词元或者词元序列对应的数字索引或数字索引序列,反之亦然,我们同样可以通过调用decode函数,返回输入数字索引或数字索引序列对应的词元或词元序列。

使用collections中的CounterOrderedDict统计英/德语每个单词在整体文本中出现的频率。构建词频字典,然后再将词频字典转为词典。其中,收录所有源语言(德语)词元的词典为de_vocab,收录所有目标语言(英语)词元的词典为en_vocab

在分配数字索引时有一个小技巧:常用的词元对应数值较小的索引,这样可以节约空间。

数据迭代器

数据预处理的最后一步是创建数据迭代器。截至目前,我们已经通过数据加载器Multi30K将源语言(德语)与目标语言(英语)的文本描述转换为词元序列,并构建了词元与数字索引一一对应的词典,接下来,需要将词元序列转换为数字索引序列。

还是以“Hello world!”为例,我们逐步演示数据迭代器中的操作

  1. 我们将表示开始和结束的特殊词元 <bos> 和 <eos> 分别添加在每个词元序列的句首和句尾。

[“hello”, “world”, “!”] --> [“<bos>”, “hello”, “world”, “!”, “<eos>”]

  1. 统一序列长度(超出长度的进行截断,未达到长度的通过填充 <pad> 进行补齐),同时记录序列的有效长度。此处假定统一的长度为7。

[“<bos>”, “hello”, “world”, “!”, “<eos>”] --> [“”, “hello”, “world”, “!”, “<eos&gt”, “<pad>”, “<pad>”], valid length = 5

  1. 最后,对文本序列进行批处理。对于每个batch中的序列,通过调用词典中的encode为序列中的所有词元找到其对应的数字索引,将结果以Tensor的形式返回。

[“<bos>”, “hello”, “world”, “!”, “<eos>”, “<pad>”, “<pad>”] --> [2, 4, 5, 6, 3, 1, 1] --> tensor

模型训练 & 模型评估

定义损失函数与优化器。

  • 损失函数:定义如何计算模型输出(logits)与目标(targets)之间的误差,这里可以使用交叉熵损失(CrossEntropyLoss)
  • 优化器:MindSpore将模型优化算法的实现称为优化器。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。

模型训练逻辑

MindSpore在模型训练部分使用了函数式编程(FP)。

构造函数 → 函数变换 → 函数调用 \text{构造函数}\rightarrow \text{函数变换} \rightarrow \text{函数调用} 构造函数函数变换函数调用

  1. Network+loss function直接构造正向函数
  2. 函数变换,获得梯度计算(反向传播)函数
  3. 构造训练过程函数
  4. 调用函数进行训练

定义前向网络计算逻辑。

在训练过程中,表示句子结尾的 <eos> 占位符应是被模型预测出来,而不是作为模型的输入,所以在处理 Decoder 的输入时,我们需要移除目标序列最末的 <eos> 占位符。

trg = [, x_1, x_2, …, x_n, ]

trg[:-1] = [, x_1, x_2, …, x_n]

其中, x i x_i xi代表目标序列中第i个表示实际内容的词元。

我们期望最终的输出包含表示句末的 <eos> ,不包含表示句首的 <bos>,所以在计算损失时,需要同样去除的目标序列的句首 <bos> 占位符,再进行比较。

output = [y_1, y_2, …, y_n, <eos>]

trg[1:] = [x_1, x_2, …, x_n, <bos>]

其中, y i y_i yi表示预测的第i个实际内容词元。

定义梯度计算函数。

为了优化模型参数,需要求参数对loss的导数。我们调用mindspore.ops.value_and_grad函数,来获得function的微分函数。

value-and-grad

常用到的参数有三种:

  • fn:待求导的函数;
  • grad_position:指定求导输入位置的索引;
  • weights:指定求导的参数;
    由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数。

定义整体训练逻辑。

在训练中,模型会以最小化损失为目标更新模型权重,故模型状态需设置为训练model.set_train(True)

定义模型评估逻辑。

在评估中,仅需正向计算loss,无需更新模型参数,故模型状态需设置为训练model.set_train(False)

模型训练

数据集遍历迭代,一次完整的数据集遍历成为一个epoch。我们逐个epoch打印训练的损失值和评估精度,并通过save_checkpoint保存评估精度最高的ckpt文件(transformer.ckpt)到home_path/.mindspore_examples/transformer.ckpt。

模型推理

首先,通过load_checkpointload_param_into_net将训练好的模型参数加载入新实例化的模型中。

推理过程中无需对模型参数进行更新,所以这里model.set_train(False)

我们输入一个德文语句,期望可以返回翻译好的英文语句。

首先通过Encoder提取德文序列中的特征信息,并将其传输至Decoder。

Decoder最开始的输入为起始占位符 <bos>,每次会根据输入预测下一个出现的单词,并对输入进行更新,直到预测出终止占位符 <eos> 。

BLEU得分

双语替换评测得分(bilingual evaluation understudy,BLEU)为衡量文本翻译模型生成出来的语句好坏的一种算法,它的核心在于评估机器翻译的译文 pred \text{pred} pred 与人工翻译的参考译文 label \text{label} label 的相似度。通过对机器译文的片段与参考译文进行比较,计算出各个片段的的分数,并配以权重进行加和,基本规则为:

  1. 惩罚过短的预测,即如果机器翻译出来的译文相对于人工翻译的参考译文过于短小,则命中率越高,需要施加更多的惩罚;
  2. 对长段落匹配更高的权重,即如果出现长段落的完全命中,说明机器翻译的译文更贴近人工翻译的参考译文;

BLEU的公式如下:

e x p ( m i n ( 0 , 1 − l e n ( label ) l e n ( pred ) ) Π n = 1 k p n 1 / 2 n ) exp(min(0, 1-\frac{len(\text{label})}{len(\text{pred})})\Pi^k_{n=1}p_n^{1/2^n}) exp(min(0,1len(pred)len(label))Πn=1kpn1/2n)

  • len(label):人工翻译的译文长度
  • len(pred):机器翻译的译文长度
  • p_n:n-gram的精度

我们可以调用nltk中的corpus_bleu函数来计算BLEU,在此之前,需要手动下载nltk

pip install nltk

from nltk.translate.bleu_score import corpus_bleudef calculate_bleu(dataset, max_len=50):trgs = []pred_trgs = []for data in dataset[:10]:src = data[0]trg = data[1]pred_trg = inference(src, max_len)pred_trgs.append(pred_trg)trgs.append([trg])return corpus_bleu(trgs, pred_trgs)bleu_score = calculate_bleu(test_dataset)print(f'BLEU score = {bleu_score*100:.2f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146273.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

分布式消息中间件kafka

文章目录 什么是kafka?整体架构 kafka核心概念1. 生产者 (Producer)2. 消费者 (Consumer)3. 主题 (Topic)4. 分区 (Partition)5. 经纪人 (Broker)6. 复制 (Replication)7. 消费者组 (Consumer Group)8. 日志段 (Log Segment) 主要功能1. 高吞吐量2. 可靠的消息传递3. 发布/订阅…

信息安全数学基础(19)同余式的基本概念及一次同余式

一、同余式概念 同余式是数论中的一个基本概念&#xff0c;用于描述两个数在除以某个数时所得的余数相同的情况。具体地&#xff0c;设m是一个正整数&#xff0c;a和b是两个整数&#xff0c;如果a和b除以m的余数相同&#xff0c;则称a和b模m同余&#xff0c;记作a≡b(mod m)。反…

电脑ip地址怎么换地区:操作步骤与利弊分析

在当今全球化的信息时代&#xff0c;人们经常需要访问不同地区的网络资源。然而&#xff0c;由于地理位置的限制&#xff0c;某些内容或服务可能只对特定地区的用户开放。这时&#xff0c;更换电脑IP地址的地区就成为了一个实用的解决方案。本文将详细介绍两种更换电脑IP地址地…

Excel--DATEDIF函数的用法及参数含义

DATEDIF函数的用法为: DATEDIF(start_date,end_date,unit),start_date表示的是起始时间&#xff0c;end_date表示的是结束时间。unit表示的是返回的时间代码&#xff0c;是天、月、年等。如下: Datedif函数的参数含义unit参数返回值的意义"y"两个时间段之间的整年数…

QEMU:模拟 ARM 大端字节序运行环境

文章目录 1. 前言2. ARM 大小端模拟测试2.1 裸机模拟测试2.1.1 大端模拟测试2.1.2 小端模拟测试 2.2 用户空间模拟测试2.2.1 大端模拟测试2.2.2 小端模拟测试 2.3 结论 3. 参考链接 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&…

PLC通信协议的转化

在自动化程序设计中&#xff0c;常常需要对通信协议进行相互转化。例如&#xff0c;某个控制器需要通过PLC控制设备的某个部件的运动&#xff0c;但PLC只支持ModbusTCP协议&#xff0c;而控制器只支持CanOpen通讯协议。这时&#xff0c;就需要一个网关进行通信协议的转化。网关…

双击就可以打开vue项目,而不用npm run dev

右键点击桌面或其他位置&#xff0c;选择“新建” -> “快捷方式”&#xff0c;在“对象的位置”处直接输入“npm run dev”&#xff0c;然后下一步 自定义一个快捷方式名称 完成后&#xff0c;桌面会创建一个快捷方式&#xff0c;右键快捷方式选择属性&#xff0c;可以看…

MiniCPM3-4B | 笔记本电脑运行端侧大模型OpenBMB/MiniCPM3-4B-GPTQ-Int4量化版 | PyCharm环境

MiniCPM3-4B&#xff0c;轻松在笔记本电脑上运行大模型&#xff1f; 背景一、选择模型二、模型下载三、模型运行四、总结 背景 2024年9月5日&#xff0c;面壁智能发布了MiniCPM3-4B&#xff0c;面壁的测试结果声称MiniCPM3-4B表现超越 Phi-3.5-mini-instruct 和 GPT-3.5-Turbo-…

proxy认识一下

免责声明:本文仅做分享。 遵守规则&#xff0c;自行跳过。 Proxy 代理技术介绍 1. 代理简介 代理&#xff08;Proxy&#xff09; 是指在客户端和目标服务器之间充当中介的设备或应用程序。代理服务器的主要功能是接收客户端的请求&#xff0c;并将这些请求转发给目标服务器&a…

解决Mac下Vscode编译运行C语言程序会自动生成DSYM文件夹的问题

&#x1f389; 前言 好久没写C语言了&#xff0c;今天打开Vscode打算写点程序练练手&#xff0c;结果发现一个让我非常苦恼的事情&#xff0c;那就是每次我运行程序的时候&#xff0c;左侧的资源管理器就会生成一大堆的文件&#xff0c;如图&#xff1a; 强迫症犯了&#xff…

模方单体化建模,建模的时候画线突然无法显示垂直线,如何解决?

垂直线对应线都可以在联动软件中设定。 模方是一款针对实景三维模型的冗余碎片、水面残缺、道路不平、标牌破损、纹理拉伸模糊等共性问题研发的实景三维模型修复编辑软件。模方4.1新增自动单体化建模功能&#xff0c;支持一键自动提取房屋结构&#xff0c;平均1栋复杂建筑物只…

机器翻译之数据处理

目录 1.导包 2.读取本地数据 3.定义函数&#xff1a;数据预处理 4.定义函数&#xff1a;词元化 5.统计每句话的长度的分布情况 6. 获取词汇表 7. 截断或者填充文本序列 8.将机器翻译的文本序列转换成小批量tensor 9.加载数据 10.知识点个人理解 1.导包 #导包 import o…

2016年国赛高教杯数学建模A题系泊系统的设计解题全过程文档及程序

2016年国赛高教杯数学建模 A题 系泊系统的设计 近浅海观测网的传输节点由浮标系统、系泊系统和水声通讯系统组成&#xff08;如图1所示&#xff09;。某型传输节点的浮标系统可简化为底面直径2m、高2m的圆柱体&#xff0c;浮标的质量为1000kg。系泊系统由钢管、钢桶、重物球、…

生信初学者教程(四):软件

文章目录 RRstudioLinux系统其他软件本书是使用R语言编写的教程,用户需要下载R和RStudio软件用于进行分析。 版权归生信学习者所有,禁止商业和盗版使用,侵权必究 R R语言是一种免费的统计计算和图形化编程语言,是一种用于数据分析和统计建模的强大工具。它具有丰富的统计…

solidwork找不到曲面

如果找不到曲面 则右键找到选项卡&#xff0c;选择曲面

Pybullet 安装过程

Pybullet 安装过程 1. 安装C编译工具2. 安装Pybullet 1. 安装C编译工具 pybullet 需要C编译套件&#xff0c;直接装之前检查下&#xff0c;要不会报缺少某版本MVSC的error&#xff0c;最好的方式是直接下载visual studio&#xff0c;直接按默认的来装。 2. 安装Pybullet 这里…

Mycat中间件

一、案例目标 &#xff08;1&#xff09;了解Mycat提供的读写分离功能。 &#xff08;2&#xff09;了解MySQL数据库的主从架构。 &#xff08;3&#xff09;构建以Mycat为中间件的读写分离数据库集群。 二、案例分析 1.规划节点 使用Mycat作为数据库中间件服务构建读写分…

聚观早报 | 小米三折叠手机专利曝光;李斌谈合肥投资蔚来

聚观早报每日整理最值得关注的行业重点事件&#xff0c;帮助大家及时了解最新行业动态&#xff0c;每日读报&#xff0c;就读聚观365资讯简报。 整理丨Cutie 9月20日消息 小米三折叠手机专利曝光 李斌谈合肥投资蔚来 索尼PS5 Pro包装亮相 新一代Spectacles AR眼镜发布 通…

媒体专访 | CertiK首席安全官李康教授:变化中的加密资产监管环境带来了新机遇

在2024韩国区块链周期间&#xff0c;CertiK首席安全官李康教授接受了韩国媒体E-Today的独家专访。采访中&#xff0c;李康教授探讨了加密资产监管环境的最新动态及其为行业带来的新机遇。同时&#xff0c;他也表达了对加密资产生态系统所面临的安全挑战的担忧&#xff0c;并强调…

无人机视角应急救援(人)数据集

无人机视角应急救援&#xff08;人&#xff09;&#xff0c;两个数据集 part1&#xff0c;使用DJI Phantom 4A拍摄&#xff0c;分辨率为19201080像素。山区场景&#xff0c;图像中人员姿势分为站立、坐着、躺着、行走、奔跑。共1981张图像6500个不同姿势的标记&#xff0c; par…