基于Python的自然语言处理系列(16):TorchText + CNN + Teacher Forcing

        在本篇文章中,我们将实现 卷积序列到序列学习模型(Convolutional Sequence to Sequence Learning)。与之前介绍的基于循环神经网络(RNN)的模型不同,卷积模型不依赖递归成分,而是通过卷积层(CNN)来实现序列间的特征提取与学习。我们还将继续使用 Teacher Forcing 来增强训练效果。

模型简介

        本模型的核心区别在于使用了 卷积层 来处理输入序列,而非递归结构。卷积层通过滑动窗口和滤波器机制逐步处理输入序列的多个连续词汇片段,并从中提取特征。在这篇教程中,我们将使用 1024 个滤波器,每个滤波器可以“看到”3个连续词汇,从而为模型提取有意义的特征,这些特征会被进一步传递给下一层。

        下图展示了基于卷积的序列学习模型的工作原理:

        在这个模型中,编码器 将源序列(原始语言的句子)编码成上下文向量(context vector),而 解码器 则基于该上下文向量生成目标序列(翻译后的句子)。为了在卷积网络中体现输入序列的顺序信息,模型还使用了 位置嵌入(Positional Embedding)

实现步骤

1. 数据加载与预处理

        首先,我们使用 TorchText 加载 Multi30k 数据集,这个数据集包含英语到德语的翻译句子。为了让卷积神经网络处理更高效,我们将批量处理的方式改为以 batch first 的形式,即批次维度位于最前。这样可以确保 CNN 模型能够正确接收数据。

import torch, torchdata, torchtext
from torch import nn
import random, math, timedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True# 加载 Multi30k 数据集
from torchtext.datasets import Multi30kSRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'de'train = Multi30k(split=('train'), language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))# 数据预处理,包括分词和构建词汇表
from torchtext.data.utils import get_tokenizer
token_transform = {SRC_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),TRG_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm')
}vocab_transform = {}
from torchtext.vocab import build_vocab_from_iteratorfor ln in [SRC_LANGUAGE, TRG_LANGUAGE]:vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train, ln), min_freq=2, specials=['<unk>', '<pad>', '<sos>', '<eos>'])

2. 模型设计

        本模型的设计分为 编码器(Encoder)解码器(Decoder) 两部分。编码器负责将输入序列转换为上下文向量,解码器则根据上下文向量生成翻译后的目标序列。

2.1 编码器

        编码器的主要任务是将输入的句子转换为特征向量。它使用卷积层从输入句子中提取特征,并通过 位置嵌入 提供位置信息,以保持序列的顺序。

class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers, kernel_size, dropout, device, max_length=100):super().__init__()self.device = deviceself.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)self.tok_embedding = nn.Embedding(input_dim, emb_dim)self.pos_embedding = nn.Embedding(max_length, emb_dim)self.emb2hid = nn.Linear(emb_dim, hid_dim)self.hid2emb = nn.Linear(hid_dim, emb_dim)self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim, out_channels=2*hid_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) for _ in range(n_layers)])self.dropout = nn.Dropout(dropout)def forward(self, src):batch_size = src.shape[0]src_len = src.shape[1]pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)tok_embedded = self.tok_embedding(src)pos_embedded = self.pos_embedding(pos)embedded = self.dropout(tok_embedded + pos_embedded)conv_input = self.emb2hid(embedded).permute(0, 2, 1)for conv in self.convs:conved = F.glu(conv(self.dropout(conv_input)), dim=1)conved = (conved + conv_input) * self.scaleconv_input = convedconved = self.hid2emb(conved.permute(0, 2, 1))combined = (conved + embedded) * self.scalereturn conved, combined
2.2 解码器

        解码器的设计与编码器类似,但解码器会生成目标序列。通过卷积层,解码器并行处理序列中的每个词,并结合编码器的输出生成翻译结果。

class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers, kernel_size, dropout, trg_pad_idx, device, max_length=100):super().__init__()self.kernel_size = kernel_sizeself.trg_pad_idx = trg_pad_idxself.device = deviceself.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)self.tok_embedding = nn.Embedding(output_dim, emb_dim)self.pos_embedding = nn.Embedding(max_length, emb_dim)self.emb2hid = nn.Linear(emb_dim, hid_dim)self.hid2emb = nn.Linear(hid_dim, emb_dim)self.fc_out = nn.Linear(emb_dim, output_dim)self.convs = nn.ModuleList([nn.Conv1d(in_channels=hid_dim, out_channels=2*hid_dim, kernel_size=kernel_size) for _ in range(n_layers)])self.dropout = nn.Dropout(dropout)def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))combined = (conved_emb + embedded) * self.scaleenergy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))attention = F.softmax(energy, dim=2)attended_encoding = torch.matmul(attention, encoder_combined)attended_encoding = self.attn_emb2hid(attended_encoding)attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scalereturn attention, attended_combineddef forward(self, trg, encoder_conved, encoder_combined):batch_size = trg.shape[0]trg_len = trg.shape[1]pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)tok_embedded = self.tok_embedding(trg)pos_embedded = self.pos_embedding(pos)embedded = self.dropout(tok_embedded + pos_embedded)conv_input = self.emb2hid(embedded).permute(0, 2, 1)for conv in self.convs:padding = torch.zeros(batch_size, conv_input.shape[1], self.kernel_size - 1).fill_(self.trg_pad_idx).to(self.device)padded_conv_input = torch.cat((padding, conv_input), dim=2)conved = F.glu(conv(padded_conv_input), dim=1)attention, conved = self.calculate_attention(embedded, conved, encoder_conved, encoder_combined)conved = (conved + conv_input) * self.scaleconv_input = convedconved = self.hid2emb(conved.permute(0, 2, 1))output = self.fc_out(self.dropout(conved))return output, attention

3. 模型训练

# 定义训练和评估函数
def train(model, loader, optimizer, criterion, clip, loader_length):model.train()epoch_loss = 0for src, trg in loader:src, trg = src.to(device), trg.to(device)optimizer.zero_grad()output, _ = model(src, trg[:, :-1])output = output.reshape(-1, output.shape[-1])trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / loader_lengthdef evaluate(model, loader, criterion, loader_length):model.eval()epoch_loss = 0with torch.no_grad():for src, trg in loader:src, trg = src.to(device), trg.to(device)output, _ = model(src, trg[:, :-1])output = output.reshape(-1, output.shape[-1])trg = trg[:, 1:].reshape(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / loader_length

4. 实验与评估

        我们使用相同的数据集来训练和评估模型,并在模型性能上进行分析。由于 CNN 并行计算的优势,模型在处理速度上表现出显著提升。

# 训练模型
best_valid_loss = float('inf')
num_epochs = 10
clip = 0.1
for epoch in range(num_epochs):train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'convseq2seq-model.pt')

结语

        在这篇文章中,我们成功实现了基于 CNN 的序列到序列学习模型,并结合 Teacher Forcing 技术加速了模型的训练。与传统的 RNNLSTM 模型相比,卷积神经网络具备强大的并行计算能力,使其在处理长序列时更加高效。虽然 CNN 模型没有显式的递归结构,但通过 卷积层位置嵌入,它能够有效捕捉输入序列中的时序信息。

        此外,尽管 CNN 模型不能像 RNN 那样灵活地使用 Teacher Forcing,但其速度优势使得它在某些任务中表现出色。在接下来的文章中,我们将探索更加复杂的 Transformer 模型,该模型也具备并行处理能力,并且在多个自然语言处理任务中已证明了其强大的性能。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147548.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

增强现实系列—Map-Relative Pose Regression for Visual Re-Localization

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…

基于JAVA+SpringBoot+Vue的社区智慧养老监护管理平台

基于JAVASpringBootVue的社区智慧养老监护管理平台 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末附源码下载链接&#x1…

科研绘图系列:R语言多个AUC曲线图(multiple AUC curves)

文章目录 介绍加载R包导入数据数据预处理画图输出结果组图系统信息介绍 多个ROC曲线在同一张图上可以直观地展示和比较不同模型或方法的性能。这种图通常被称为ROC曲线图,它通过比较不同模型的ROC曲线下的面积(AUC)大小来比较模型的优劣。AUC值越大,模型的诊断或预测效果越…

前后端跨域问题及其在ThinkPHP中的解决方案

在现代Web开发中&#xff0c;前后端分离的架构越来越普遍&#xff0c;但这也带来了跨域问题。跨域指的是在一个域下的网页试图请求另一个域的资源&#xff0c;浏览器出于安全考虑会限制这种行为。本文将探讨如何在ThinkPHP中解决跨域问题。 #### 1. 什么是跨域&#xff1f; 跨…

一个皮肤科医生长痘的的自救

内服 复方锌铁钙口服液 丹参瞳胶囊 盐酸米诺环素胶囊 (每天一次) 内服 外用: 克林霉素甲硝搽剂 (泛红的痘痘) 人表皮生长因子(痘印)氢醌软膏 (点阵激光留下的色沉 早晚一次) 至少用两个月【痤疮|痘痘用药 一个皮肤科医生的自救】https://www.bilibili.com/video/BV1zu41…

算法题之每日温度

每日温度 给定一个整数数组 temperatures &#xff0c;表示每天的温度&#xff0c;返回一个数组 answer &#xff0c;其中 answer[i] 是指对于第 i 天&#xff0c;下一个更高温度出现在几天后。如果气温在这之后都不会升高&#xff0c;请在该位置用 0 来代替。 示例 1: 输入…

java计算机毕设课设—企业车辆管理系统(附源码、文章、相关截图、部署视频)

这是什么系统&#xff1f; 资源获取方式在最下方 java计算机毕设课设—企业车辆管理系统(附源码、文章、相关截图、部署视频) 企业车辆管理系统通过计算机&#xff0c;能够直接“透视”车辆使用情况&#xff0c;数据计算自动完成&#xff0c;尽量减少人工干预&#xff0c;可…

Java项目实战II基于Java+Spring Boot+MySQL的植物健康系统(开发文档+源码+数据库)

目录 目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着…

实战指南:深度剖析Servlet+JSP+JDBC技术栈下的用户CRUD操作

本博客总结基于MVC(JSPServletJDBC)操作用户信息的CRUD&#xff08;增删改查功能&#xff09;的完整小项目。包括图片上传和回显&#xff0c;模糊查询&#xff0c;过滤器的登录校验和设置全局字符集以及监听器统计在线用户人数等额外功能&#xff0c;因为代码较多&#xff0c;我…

UnLua实现继承

一、在蓝图中实现继承 1、创建父类&#xff0c;并绑定Lua脚本 2、创建子类蓝图&#xff0c;如果先创建的子类&#xff0c;可以修改父类继承 注意&#xff0c;提示选择继承父类的接口&#xff01; 二、在Lua中实现继承 1、在父类Lua脚本中实现函数 BP_CharacterBase.lua func…

构建数字化生态系统:打造数字化转型中开放协作平台的最佳实践和关键实施技巧

在数字化转型浪潮中&#xff0c;企业如何确保成功实施至关重要。除了技术上的革新&#xff0c;企业还必须在战略执行、架构优化以及合规性管理等方面掌握最佳实践。随着云计算、大数据、人工智能等新兴技术的迅速发展&#xff0c;企业通过正确的实施技巧不仅能提升业务效率&…

Qemu开发ARM篇-3、qemu运行uboot演示

文章目录 1、运行uboot2、qemu常用命令 在上一篇Qemu开发ARM篇-2、uboot交叉编译文章中&#xff0c;我们搭建了交叉编译工具链&#xff0c;并成功进行了uboot的交叉编译&#xff0c;在该篇中&#xff0c;我们将演示如何利用qemu运行上一篇中交叉编译的uboot程序。 1、运行uboo…

计算机毕业设计之:基于微信小程序的学生考勤系统的设计与实现(源码+文档+讲解)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

Redis——redispluspls库hash及zset类型相关接口使用

文章目录 hash类型相关接口hset和hgethexistshdelhkeys 和 hvalshmset和hmget zset类型相关接口zadd和zrangezcard 和 zremzscore和zrank hash类型相关接口 hset和hget std::cout<<"hset 和 hget"<<std::endl;redis.flushall();redis.hset("key&qu…

Java 分布式锁:原理与实践

在分布式系统中&#xff0c;多个节点同时操作共享资源的情况非常普遍。为了保证数据的一致性&#xff0c;分布式锁 应运而生。分布式锁 是一种跨多个服务器的互斥锁&#xff0c;用于协调分布式环境下的资源访问。 本文将介绍 Java 实现分布式锁 的几种常见方式&#xff0c;并结…

OpenAI API: How to catch all 5xx errors in Python?

题意&#xff1a;OpenAI API&#xff1a;如何在 Python 中捕获所有 5xx 错误&#xff1f; 问题背景&#xff1a; I want to catch all 5xx errors (e.g., 500) that OpenAI API sends so that I can retry before giving up and reporting an exception. 我想捕获 OpenAI API…

浙大数据结构:05-树8 File Transfer

数据结构MOOC PTA习题 这道题考察并查集的操作&#xff0c;合并以及找根结点 机翻&#xff1a; 1、条件准备 node是数组存放1-N结点的根节点的&#xff0c;n为总结点数 #include <iostream> using namespace std;const int N 1e4 5; int node[N]; int n; 先初始化…

<<编码>> 第 16 章 存储器组织(3)--3-8 译码器 示例电路

3-8 译码器 info::操作说明 “写入” 开关先断开(Q 为低电平, 表示不写入) S2-S1-S0 设置一个二进制数, 选中 Q0~Q7 其中一个作为 Q 的输出 “数据输入” 端置入要保存的数(0或1) 闭合 “写入” 开关, 对应的一位锁存器中的 W 为高电平, 表示可以写入, “数据输入” 的值最终…

嵌入式常用GUI介绍

目录 前言一、GuiLite二、LVGL三、SimpleGUI四、MiniGUI五、emWin六、TouchGFX七、uGUI八、GFX九、Embedded Wizard十、CrankSoftware十一、PEG Graphics Software十二、Guiliani十三、MPLAB Harmony Graphics Suite 前言 图形用户界面&#xff08;Graphical User Interface&am…

关系数据库设计之Armstrong公理详解

~犬&#x1f4f0;余~ “我欲贱而贵&#xff0c;愚而智&#xff0c;贫而富&#xff0c;可乎&#xff1f; 曰&#xff1a;其唯学乎” 一、Armstrong公理简介 Armstrong公理是一组在关系数据库理论中用于推导属性依赖的基本规则。这些公理是以著名计算机科学家威廉阿姆斯特朗&…