基于Python的自然语言处理系列(10):使用双向LSTM进行文本分类

        在前一篇文章中,我们介绍了如何使用RNN进行文本分类。在这篇文章中,我们将进一步优化模型,使用双向多层LSTM来替代RNN,从而提高模型在序列数据上的表现。LSTM通过引入一个额外的记忆单元(cell state)来解决标准RNN中的梯度消失问题。此外,双向LSTM能够同时考虑句子前后的信息,进一步提高模型的性能。

1. LSTM与RNN的区别

        标准RNN容易在处理长序列时出现梯度消失或爆炸的现象,导致模型难以学习长期依赖。LSTM通过引入一个额外的cell state来存储和控制长期信息的流动,避免了梯度消失的问题。具体来说,LSTM使用了三个门来控制信息的流动:输入门、遗忘门和输出门。

        LSTM的计算公式如下:

        我们将在本文中实现一个双向多层LSTM,即同时使用正向和反向的LSTM来处理文本序列。

2. 数据预处理与FastText词嵌入

        首先,我们加载数据集,并使用与前面文章类似的预处理方法,包括使用spacy进行标记化、创建词汇表,并引入预训练的FastText词嵌入。

from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import FastText# 加载数据集
train, test = AG_NEWS()# 使用spacy进行标记化
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')# 构建词汇表
def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab["<unk>"])# 引入FastText词嵌入
fast_vectors = FastText(language='simple')
fast_embedding = fast_vectors.get_vecs_by_tokens(vocab.get_itos()).to(device)

3. LSTM模型设计

        在这部分中,我们设计了一个双向多层LSTM模型。我们使用nn.LSTM代替nn.RNN,并通过设置bidirectional=True来启用双向LSTM。此外,我们还将使用多层LSTM,通过设置num_layers=2来增加模型的复杂度。

import torch.nn as nnclass LSTM(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout):super().__init__()# 嵌入层self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=vocab['<pad>'])# 双向多层LSTMself.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout,batch_first=True)# 全连接层,接收双向LSTM的输出,因此乘以2self.fc = nn.Linear(hid_dim * 2, output_dim)def forward(self, text, text_lengths):# 嵌入层embedded = self.embedding(text)# 打包序列packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), enforce_sorted=False, batch_first=True)# 通过LSTMpacked_output, (hn, cn) = self.lstm(packed_embedded)# 解包序列output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)# 拼接正向和反向LSTM的输出hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim=1)return self.fc(hn)

4. 训练与评估

        我们将使用Adam优化器,并在训练过程中计算模型的损失和准确率。以下是完整的训练与评估代码:

import torch.optim as optim# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# 计算准确率
def accuracy(preds, y):predicted = torch.max(preds.data, 1)[1]batch_corr = (predicted == y).sum()acc = batch_corr / len(y)return acc# 训练函数
def train(model, loader, optimizer, criterion, loader_length):epoch_loss = 0epoch_acc = 0model.train()for i, (label, text, text_length) in enumerate(loader): label = label.to(device)text = text.to(device)# 前向传播predictions = model(text, text_length).squeeze(1)# 计算损失和准确率loss = criterion(predictions, label)acc  = accuracy(predictions, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / loader_length, epoch_acc / loader_length# 评估函数
def evaluate(model, loader, criterion, loader_length):epoch_loss = 0epoch_acc = 0model.eval()with torch.no_grad():for i, (label, text, text_length) in enumerate(loader): label = label.to(device)text = text.to(device)predictions = model(text, text_length).squeeze(1)loss = criterion(predictions, label)acc  = accuracy(predictions, label)epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / loader_length, epoch_acc / loader_length

        我们通过5个epoch训练模型,并保存最佳模型的状态。

num_epochs = 5
best_valid_loss = float('inf')for epoch in range(num_epochs):train_loss, train_acc = train(model, train_loader, optimizer, criterion, len(train_loader))valid_loss, valid_acc = evaluate(model, valid_loader, criterion, len(valid_loader))if valid_loss < best_valid_loss:best_valid_loss = valid_losstorch.save(model.state_dict(), 'best-model.pt')print(f'Epoch {epoch+1} | Train Loss: {train_loss:.3f}, Train Acc: {train_acc*100:.2f}%')print(f'Valid Loss: {valid_loss:.3f}, Valid Acc: {valid_acc*100:.2f}%')

5. 测试与预测

        训练完成后,我们可以使用模型对新文本进行预测。以下是如何使用训练好的模型预测随机新闻文本的类别:

def predict(text, text_length):with torch.no_grad():output = model(text, text_length).squeeze(1)predicted = torch.max(output.data, 1)[1]return predictedtest_str = "Google is now facing challenges in its business strategy."
text = torch.tensor(text_pipeline(test_str)).unsqueeze(0).to(device)
text_length = torch.tensor([text.size(1)]).to(device)prediction = predict(text, text_length)
print(f'预测结果: {prediction.item()}')

结语

        在这篇文章中,我们通过引入双向LSTM改进了文本分类模型的性能。LSTM通过其独特的记忆单元门控机制,有效解决了传统RNN中存在的梯度消失问题,从而能够更好地捕捉长序列中的依赖关系。此外,双向LSTM的加入使模型不仅能够关注序列的前向信息,还能同时捕捉序列中的反向信息,这在处理自然语言中尤为重要。毕竟,在许多语言表达中,句子前后的词语和短语之间存在密切关联,双向LSTM的设计帮助我们更全面地理解文本中的语义。

        通过实验,我们观察到,双向多层LSTM能够显著提升文本分类任务的准确性。相较于传统RNN,LSTM不仅能够捕捉更长时间步的依赖,还通过多层结构让模型具有更深的语义理解能力。使用双向LSTM,模型在多个方向上进行信息处理,进一步提升了模型的学习能力。

        尽管LSTM在序列建模中展现了其优势,但它依然存在一些局限性。例如,当处理极长的序列时,LSTM的效率可能会受到影响。此外,虽然双向LSTM能够提供更好的上下文信息,但它的计算量也相应增加,尤其是当模型层数增加时,训练时间可能会大幅增长。因此,在实际应用中,我们还需要根据具体的任务场景平衡模型的性能和计算成本。

        在未来的研究和实践中,我们可以继续探索更为先进的模型,如Transformer,它在并行计算和长序列建模方面展现了强大的能力。此外,我们也可以尝试将LSTM与其他模型(如卷积神经网络CNN)结合,进一步提高模型的表达能力。

        总的来说,LSTM为处理自然语言中的序列数据提供了强大的工具,尤其是在文本分类、机器翻译、序列标注等任务中具有广泛的应用前景。通过掌握LSTM及其变种模型,开发者可以在更多复杂的自然语言处理任务中获得显著的性能提升。

        在下一篇文章中,我们将探索如何使用**卷积神经网络(CNN)**进行文本分类,CNN以其在图像处理中的成功经验,也能为文本分类任务提供一种有效的建模方式。我们将讨论如何将CNN应用于自然语言处理任务中,并通过实验验证其效果。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1536880.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

24.Redis实现全局唯一ID

是一种分布式系统下用来生成全局唯一ID的工具。 特点 1.唯一性 2.高可用 3.高性能 4.递增性&#xff0c;数据也要保持一种递增&#xff0c;有利于数据库进行查询。 5.安全性 全局唯一ID的生成策略 1.UUID(没有顺序&#xff0c;字符串类型&#xff0c;效率不高) 2.Redis…

【电路笔记】-差分运算放大器

差分运算放大器 文章目录 差分运算放大器1、概述2、差分运算放大器表示2.1 差分模式2.2 减法器模式3、差分放大器示例3.1 相关电阻3.2 惠斯通桥3.3 光/温度检测4、仪表放大器5、总结1、概述 在之前的文章中,我们讨论了反相运算放大器和同相运算放大器,我们考虑了在运算放大器…

floodfill算法(二)

目录 一、太平洋大西洋水流问题 1. 题目链接&#xff1a;417. 太平洋大西洋水流问题 2. 题目描述&#xff1a; 3. 解法 &#x1f334;算法思路&#xff1a; &#x1f334;算法代码&#xff1a; 二、扫雷游戏 1. 题目链接&#xff1a;529. 扫雷游戏 2. 题目描述&#xf…

softmax回归的从零实现(附代码)

softmax回归是一个多分类模型&#xff0c;但是他跟线性回归一样将输入特征与权重做线性叠加&#xff0c;与线性不同的是他有多个输出&#xff0c;输出的个数对应分类标签的个数&#xff0c;比如四个特征和三种输出动物类别&#xff0c;则权重包含12个标量&#xff08;带下标的w…

深度学习之线性代数预备知识点

概念定义公式/案例标量(Scalar)一个单独的数值&#xff0c;表示单一的量。例如&#xff1a;5, 3.14, -2向量 (Vector)一维数组&#xff0c;表示具有方向和大小的量。 &#xff0c;表示三维空间中的向量 模(Magnitude)向量的长度&#xff0c;也称为范数&#xff08;通常为L2范数…

HCIA--实验十六:ACL通信实验(2)

2.高级ACL配置 一、实验内容 1.需求/要求&#xff1a; 使用三台PC和一台交换机&#xff0c;在交换机上配置高级ACL&#xff0c;测试PC1、PC2、PC3间的连通性。 二、实验过程 1.拓扑图&#xff1a; 2.步骤&#xff1a; 1.给PC3配置ip地址&#xff1a; 2.给交换机SW3配置高…

Hello,Spring Boot...

今天开启了Spring Boot学习之旅。 首先就是&#xff0c;JDK、Maven、IDEA以及各种官网的下载、安装与配置 然后通过组件创建小类&#xff0c;最让人头痛的就是&#xff0c;这个spring-boot-starter-thymeleaf&#xff0c;下错版本了 其他的一切顺利&#xff0c;自动化明显 最后…

2024最新版mysql数据库表的查询操作-总结

序言 1、MySQL表操作(创建表&#xff0c;查询表结构&#xff0c;更改表字段等)&#xff0c; 2、MySQL的数据类型(CHAR、VARCHAR、BLOB,等)&#xff0c; 本节比较重要&#xff0c;对数据表数据进行查询操作&#xff0c;其中可能大家不熟悉的就对于INNER JOIN(内连接)、LEFT JOIN…

Learn ComputeShader 15 Grass

1.Using Blender to create a single grass clump 首先blender与unity的坐标轴不同&#xff0c;z轴向上&#xff0c;不是y轴 通过小键盘的数字键可以快速切换视图&#xff0c;选中物体以后按下小键盘的点可以将物体聚焦于屏幕中心 首先我们创建一个平面&#xff0c;宽度为0.2…

SpringBoot中使用EasyExcel并行导出多个excel文件并压缩zip后下载

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

SysML图例-农业无人机

DDD领域驱动设计批评文集>> 《软件方法》强化自测题集>> 《软件方法》各章合集>>

dll修复工具4DDiG DLL Fixer,解决电脑dll丢失问题

4DDiG DLL Fixer是一款专业的DLL修复工具&#xff0c;旨在解决Windows系统中各种DLL相关问题。该工具能够快速全面地扫描计算机&#xff0c;检测并修复导致程序功能异常的DLL错误。它支持一键式操作&#xff0c;自动扫描、识别和替换缺失或损坏的DLL文件&#xff0c;从而帮助用…

推荐3款AIai论文大纲一键生成文献,精选整理!

在当前的学术写作环境中&#xff0c;AI论文大纲生成工具已经成为许多学者和学生的重要助手。这些工具不仅能够快速生成高质量的论文大纲&#xff0c;还能提供内容填充、文献引用和查重修改等全方位的服务。以下是三款值得推荐的AI论文大纲一键生成文献工具&#xff1a;千笔-AIP…

爬虫--翻页tips

免责声明&#xff1a;本文仅做分享&#xff01; 伪线程 from DrissionPage import ChromiumPage import timepage ChromiumPage() page.get("https://you.ctrip.com/sight/taian746.html") # 初始化 第0页 index_page 0# 翻页点击函数 sleep def page_turn():page…

C/C++语言基础--从C到C++的不同(下),15个部分说明C与C++的不同

本专栏目的 更新C/C的基础语法&#xff0c;包括C的一些新特性 前言 1-10在上篇C/C语言基础–从C到C的不同(上&#xff09;&#xff1b;当然C和C的不同还有很多&#xff0c;本人暂时只总结这些&#xff0c;其他的慢慢更新&#xff1b;上一篇C/C语言基础–从C到C的不同(上&…

node.js 中的进程和线程工作原理

本文所有的代码均基于 node.js 14 LTS 版本分析 概念 进程是对正在运行中的程序的一个抽象&#xff0c;是系统进行资源分配和调度的基本单位&#xff0c;操作系统的其他所有内容都是围绕着进程展开的 线程是操作系统能够进行运算调度的最小单位&#xff0c;其是进程中的一个执…

康养小站:长者舒缓疼痛的港湾

【导语】在老龄化日益加剧的当下&#xff0c;如何关爱和照顾好长者&#xff0c;成为社会关注的焦点。近日&#xff0c;笔者走进深圳宝安区一家专注于长者康养的社区小站&#xff0c;探访它如何帮助长者缓解疼痛&#xff0c;提高生活质量。 随着我国人口老龄化问题日益显著&…

算法:30.串联所有单词的子串

题目 链接&#xff1a;leetcode链接 思路分析&#xff08;滑动窗口&#xff09; 这道题目类似寻找异位词的题目&#xff0c;我认为是寻找异位词的升级版 传送门:寻找异位词 为什么说像呢&#xff1f; 注意&#xff1a;这道题目中words数组里面的字符串长度都是相同的&…

[JAVA]介绍怎样在Java中通过字节字符流实现文件读取与写入

一&#xff0c;初识File类及其常用方法 File类是java.io包下代表与平台无关的文件和目录&#xff0c;程序中操作文件和目录&#xff0c;都可以通过File类来完成。 通过这个File对象&#xff0c;可以进行一系列与文件相关的操作&#xff0c;比如判断文件是否存在&#xff0c;获…

Java毕业设计 基于SpringBoot和Vue药店管理系统

Java毕业设计 基于SpringBoot和Vue药店管理系统 这篇博文将介绍一个基于SpringBoot框架和Vue开发的药店管理系统&#xff0c;适合用于Java毕业设计。 功能介绍 首页 图片轮播 登录 注册 药品信息 药品详情 评论 收藏 购买 添加到购物车 用药指南 公告资讯 购物车 …