基于Python的自然语言处理系列(14):TorchText + biGRU + Attention + Teacher Forcing

        在前几篇文章中,我们探索了序列到序列(seq2seq)模型的基础,并通过使用双向GRU和上下文向量改进了模型的表现。然而,模型仍然依赖一个固定的上下文向量,这意味着它必须从整个源句中压缩信息,导致在长句子的翻译中可能出现问题。

        在本篇文章中,我们将引入注意力机制来解决这个问题。注意力机制允许解码器在每一步解码时不仅仅依赖一个固定的上下文向量,而是能够动态地访问源句中的所有信息。这样,模型可以在解码过程中“关注”到最相关的词,从而提升翻译的准确性,尤其是长句子。

1. 背景

        在传统的seq2seq模型中,解码器仅依赖编码器生成的一个上下文向量。尽管我们通过双向GRU改进了模型,但上下文向量仍然需要压缩整个源句的信息,限制了模型的表现。

        为了解决这个问题,注意力机制通过计算源句中每个词的权重,让解码器能够动态地关注源句中的不同部分,而不仅仅是依赖一个固定的上下文向量。这不仅提升了模型对长句子的处理能力,还提高了翻译的准确性。

2. 数据加载与预处理

        我们继续使用TorchText加载Multi30k数据集,并使用spacy进行标记化处理。数据加载的流程与之前文章中的内容相似。

from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizerSRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'de'train = Multi30k(split=('train'), language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))token_transform = {}
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TRG_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')

        与前面相同,我们将数据集分为训练集、验证集和测试集,并将文本进行数值化处理。

3. 模型设计

        在这个模型中,我们将实现一个结合了双向GRU注意力机制的seq2seq模型。模型结构包括以下几个部分:

3.1 编码器(Encoder)

        首先,我们将构建编码器。这里我们使用双向GRU,将输入序列从左到右和从右到左进行编码。编码器输出的隐状态将作为解码器的初始隐状态,并传递给注意力机制。

class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, dropout):super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hid_dim, bidirectional=True)self.fc = nn.Linear(hid_dim * 2, hid_dim)self.dropout = nn.Dropout(dropout)def forward(self, src):embedded = self.dropout(self.embedding(src))outputs, hidden = self.rnn(embedded)hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))return outputs, hidden

3.2 注意力机制(Attention)

        注意力机制的作用是计算解码器当前隐状态与源句每个词的隐状态之间的权重,帮助解码器决定应该关注源句的哪些部分。权重越大,说明该词对当前解码步骤越重要。

class Attention(nn.Module):def __init__(self, hid_dim):super().__init__()self.v = nn.Linear(hid_dim, 1, bias=False)self.W = nn.Linear(hid_dim, hid_dim)self.U = nn.Linear(hid_dim * 2, hid_dim)def forward(self, hidden, encoder_outputs):batch_size = encoder_outputs.shape[1]src_len = encoder_outputs.shape[0]hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)encoder_outputs = encoder_outputs.permute(1, 0, 2)energy = torch.tanh(self.W(hidden) + self.U(encoder_outputs))attention = self.v(energy).squeeze(2)return F.softmax(attention, dim=1)

3.3 解码器(Decoder)

        解码器在每个时刻生成一个新的目标词。它使用注意力机制得到的加权源句信息,以及解码器自身的隐状态来生成目标词的预测。

class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, dropout, attention):super().__init__()self.output_dim = output_dimself.attention = attentionself.embedding = nn.Embedding(output_dim, emb_dim)self.gru = nn.GRU((hid_dim * 2) + emb_dim, hid_dim)self.fc_out = nn.Linear((hid_dim * 2) + hid_dim + emb_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, encoder_outputs):input = input.unsqueeze(0)embedded = self.dropout(self.embedding(input))a = self.attention(hidden, encoder_outputs).unsqueeze(1)encoder_outputs = encoder_outputs.permute(1, 0, 2)weighted = torch.bmm(a, encoder_outputs).permute(1, 0, 2)rnn_input = torch.cat((embedded, weighted), dim=2)output, hidden = self.gru(rnn_input, hidden.unsqueeze(0))embedded = embedded.squeeze(0)output = output.squeeze(0)weighted = weighted.squeeze(0)prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))return prediction, hidden.squeeze(0)

3.4 Seq2Seq模型

        将编码器、解码器和注意力机制组合起来,我们构建了完整的seq2seq模型。

class Seq2SeqAttention(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = src.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.output_dimoutputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)encoder_outputs, hidden = self.encoder(src)input_ = trg[0, :]for t in range(1, trg_len):output, hidden = self.decoder(input_, hidden, encoder_outputs)outputs[t] = outputtop1 = output.argmax(1)input_ = trg[t] if random.random() < teacher_forcing_ratio else top1return outputs

4. 训练与评估

        我们使用与前几篇文章相同的训练和评估函数。为了防止梯度爆炸,我们在训练过程中应用梯度裁剪。

def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, (src, trg) in enumerate(iterator):src, trg = src.to(device), trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for i, (src, trg) in enumerate(iterator):src, trg = src.to(device), trg.to(device)output = model(src, trg, 0)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)

训练模型:

for epoch in range(10):train_loss = train(model, train_loader, optimizer, criterion, 1)val_loss = evaluate(model, valid_loader, criterion)print(f'Epoch {epoch+1} | Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}')

结语

        通过在seq2seq模型中引入注意力机制,我们成功提升了模型对长句子的处理能力。在每一个解码步骤中,解码器能够灵活地访问源句中的每个词,而不再依赖一个固定的上下文向量。这大大减少了信息压缩问题,使得模型在翻译复杂句子时更加精准。

        尽管注意力机制为模型带来了显著的改进,但训练时间相对增加。尤其在处理长句子时,模型需要计算源句中每个词的注意力权重,增加了计算复杂度。

        在下一篇文章中,我们将结合双向GRU注意力机制以及Packed Padded SequencesMasking技术,进一步优化模型的训练过程。同时,我们将展示如何通过这些技术处理不同长度的输入序列,并通过可视化注意力权重,更深入地理解模型在解码时关注哪些词。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1543203.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

并发编程 - 锁(NSLock)

引言 在多线程编程中&#xff0c;数据一致性是一个必须解决的问题。多个线程同时访问同一片共享数据时&#xff0c;极易发生竞争条件&#xff08;race conditions&#xff09;&#xff0c;导致数据的不一致性&#xff0c;甚至程序崩溃。为了解决这些问题&#xff0c;我们需要引…

大模型备案最难材料搞定——安全评估报告、安全评估测试题【评估测试题+备案源文件】

大模型备案&#xff0c;最难搞定的2个材料&#xff0c;安全评估报告&#xff0c;安全评估测试题、拦截词&#xff0c;这里都有了 文章目录 &#xff08;一&#xff09;适用主体 &#xff08;二&#xff09;语料安全 &#xff08;三&#xff09;模型安全 &#xff08;四&…

Java_Se 数据变量与运算符

标识符、变量、常量、数据类型、运算符、基本数据类型的类型转换等。这些是编程中的“砖块”&#xff0c;是编程的基础。要想开始正式编程&#xff0c;还需要再学“控制语句”&#xff0c;控制语句就像“水泥”&#xff0c;可以把“砖块”粘到一起&#xff0c;最终形成“一座大…

2024年图纸加密防泄密软件Top10榜单 | 防止CAD图纸泄密打造坚不可摧的图纸安全

在当今数字化高速发展的时代&#xff0c;图纸作为重要的知识产权和商业机密&#xff0c;其安全问题备受关注。 一旦图纸泄露&#xff0c;可能给企业和个人带来巨大的损失。 为了保护图纸的安全&#xff0c;各种加密防泄密软件应运而生。下面为大家揭晓2024 年图纸加密防泄密软…

修复 msvcr120.dll 丢失的方法,总结几个靠谱有效的方法

1. msvcr120.dll 定义 1.1 Microsoft Visual C Redistributable Package 的一部分 msvcr120.dll 是 Microsoft Visual C 2013 Redistributable Package 的核心组件&#xff0c;该包为运行时环境提供了必要的库文件。它确保了使用 Visual C 2013 编译的应用程序能够在没有开发…

想在产品上扩展大储存怎么做?开源啦!

相比伙伴们都遇到过&#xff0c;芯片内存不够的问题&#xff1a;经常会有大量的照片、音频、文档等需要存储&#xff0c;怎么办呢&#xff1f; 我们知道可以外扩&#xff0c;要编写各种驱动&#xff0c;还有Flash替换&#xff0c;这个工程不可谓不大啊&#xff01; 但&#x…

Spring中的容器接口

容器接口 首先了解一下BeanFactory和ApplicationContext这两个接口的关系。 其实在一个 SpringBoot 项目中&#xff0c;这个 SpringBoot 项目的启动类的返回值就是一个 ApplicationContext 接口的实现类。 然后在 IDEA 中选中这个类&#xff0c;按住ctrlaltU可以查看类图&…

JavaScript 安装库npm报错

今天在编写JavaScript代码时&#xff0c;缺少了包express。 const express require(express); const app express();app.get(/, (req, res) > {res.send(Hello, world!); });app.listen(3000, () > {console.log(Server is running on port 3000); });npm install exp…

小程序开发设计-小程序的宿主环境:组件⑦

上一篇文章导航&#xff1a; 小程序开发设计-小程序的宿主环境&#xff1a;宿主环境简介⑥-CSDN博客https://blog.csdn.net/qq_60872637/article/details/142425131?spm1001.2014.3001.5501 注&#xff1a;不同版本选项有所不同&#xff0c;并无大碍。 1.小程序中组件的分类…

深度学习(1):基础概念与创建项目

文章目录 基础概念创建项目1.在Anaconda上创建序虚拟环境2.创建PyProject3.创建完成 基础概念 CPU&#xff08;中央处理器&#xff09; CPU 是计算机的核心部件&#xff0c;负责执行计算和逻辑操作。它按照指令序列进行任务处理&#xff0c;擅长处理串行任务。CPU 的性能直接…

【MyBatis 源码拆解系列】MyBatis 运行原理 - 读取 xml 配置文件

欢迎关注公众号&#xff08;通过文章导读关注&#xff1a;【11来了】&#xff09;&#xff0c;持续 分享大厂系统设计&#xff01; 在我后台回复 「资料」 可领取编程高频电子书&#xff01; 在我后台回复「面试」可领取硬核面试笔记&#xff01; 文章导读地址&#xff1a;点击…

Nature:科研论文中正确使用ChatGPT的三个原则

我是娜姐 迪娜学姐 &#xff0c;一个SCI医学期刊编辑&#xff0c;探索用AI工具提效论文写作和发表。 美国科罗拉多大学安舒茨医学院的生物医学信息学研究员Milton Pividori&#xff0c;一直在探索如何将ChatGPT等AI工具该技术融入课题组日常科研任务&#xff0c;例如进行文献综…

远程升级不成功?背后“凶手”可能是模组差分包…

最近有客户反馈在乡村里频繁出现掉线的情况。通过换货、换SIM卡对比排查测试&#xff0c;发现只有去年某批采购的那批模块在客户环境附近会出现掉线的情况&#xff0c;而今年采购的模块批次就不会掉线&#xff0c;很奇怪。 这个出问题的模块&#xff0c;就是合宙4G-Cat.1低功耗…

01.前端面试题之ts:说说如何在Vue项目中应用TypeScript?

文章目录 一、前言二、使用Componentcomputed、data、methodspropswatchemit 三 、总结 一、前言 与link类似 在VUE项目中应用typescript&#xff0c;我们需要引入一个库vue-property-decorator&#xff0c; 其是基于vue-class-component库而来&#xff0c;这个库vue官方推出…

数据驱动农业——农业中的大数据革命

橙蜂智能公司致力于提供先进的人工智能和物联网解决方案&#xff0c;帮助企业优化运营并实现技术潜能。公司主要服务包括AI数字人、AI翻译、埃域知识库、大模型服务等。其核心价值观为创新、客户至上、质量、合作和可持续发展。 橙蜂智农的智慧农业产品涵盖了多方面的功能&…

静态链接和动态链接的Golang二进制文件

关注TechLead&#xff0c;复旦博士&#xff0c;分享云服务领域全维度开发技术。拥有10年互联网服务架构、AI产品研发经验、团队管理经验&#xff0c;复旦机器人智能实验室成员&#xff0c;国家级大学生赛事评审专家&#xff0c;发表多篇SCI核心期刊学术论文&#xff0c;阿里云认…

抖音截流神器发布:不限量评论采集,实时推送,提升运营效率

在短视频风靡的今天&#xff0c;抖音成为品牌营销的新战场。如何在海量内容中脱颖而出&#xff0c;提升运营效率成为关键。本文将揭秘一款革命性的抖音运营工具&#xff0c;它不仅支持不限量评论采集&#xff0c;还实现了实时推送功能&#xff0c;助力运营者精准把握用户反馈&a…

保姆级 Stable Diffusion 教程,看完这篇就够了!

在美国科罗拉多州举办了一场新兴数字艺术家竞赛&#xff0c;一幅名为《太空歌剧院》的作品获得“数字艺术/数字修饰照片”类别的一等奖&#xff0c;神奇的是&#xff0c;该作品的作者并没有绘画基础&#xff0c;这幅画是他用 AI 生成的。 这让人们充分见识到AI 在绘画领域惊人的…

Shell实战(一)

Shell实战&#xff08;一&#xff09; 导语程序实例解压缩交互功能描述代码和运行结果实现解析 监视CPU和内存功能描述代码和运行结果实现解析 用户管理功能描述代码和运行结果实现解析 总结 导语 本篇引入三个书上的shell程序设计项目&#xff0c;由于书上的版本较老&#xf…