【MindSpore学习打卡】应用实践-自然语言处理-基于RNN的情感分类:使用MindSpore实现IMDB影评分类

情感分类是自然语言处理(NLP)中的一个经典任务,广泛应用于社交媒体分析、市场调研和客户反馈等领域。本篇博客将带领大家使用MindSpore框架,基于RNN(循环神经网络)实现一个情感分类模型。我们将详细介绍数据准备、模型构建、训练与评估等步骤,并最终实现对自定义输入的情感预测。

数据准备

下载和加载数据集

  • 为什么要使用requeststqdm库? requests库提供了简洁的HTTP请求接口,方便我们从网络上下载数据。tqdm库则用于显示下载进度条,帮助我们实时了解下载进度,提高用户体验。
  • 为什么要使用临时文件和shutil库? 使用临时文件可以确保下载的数据在出现意外中断时不会影响最终保存的文件。shutil库提供了高效的文件操作方法,确保数据能正确地从临时文件复制到最终保存路径。

我们使用IMDB影评数据集,这是一个经典的情感分类数据集,包含积极(Positive)和消极(Negative)两类影评。为了方便数据下载和处理,我们首先设计一个数据下载模块。

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路径
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_path

下载IMDB数据集并进行解压和加载:

import re
import six
import string
import tarfileclass IMDBData():label_map = {"pos": 1, "neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)

加载训练数据集进行测试,输出数据集数量:

imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_train = IMDBData(imdb_path, 'train')
print(len(imdb_train))

加载预训练词向量

为什么要使用预训练词向量? 预训练词向量(如Glove)是基于大规模语料库训练得到的,能够捕捉到丰富的语义信息。使用预训练词向量可以提高模型的性能,尤其是在数据量较小的情况下。

  • 为什么要添加<unk><pad>标记符? <unk>标记符用于处理词汇表中未出现的单词,避免模型在遇到新词时无法处理。<pad>标记符用于填充不同长度的文本序列,使其能够被打包为一个batch进行并行计算,提高训练效率。

我们使用Glove预训练词向量对文本进行编码。以下是加载Glove词向量的代码:

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))embeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddingsglove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
print(len(vocab.vocab()))

数据集预处理

为什么要对文本进行分词和去除标点? 分词和去除标点是文本预处理的基础步骤,有助于模型更好地理解文本的语义。将文本转为小写可以减少词汇表的大小,提高模型的泛化能力。

对加载的数据集进行预处理,包括将文本转为index id序列,并进行序列填充。

import mindspore as ms
import mindspore.dataset as dslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
imdb_valid = imdb_valid.batch(64, drop_remainder=True)

模型构建

接下来,我们构建用于情感分类的RNN模型。模型主要包括以下几层:

  1. Embedding层:将单词的index id转为词向量。
  2. RNN层:使用LSTM进行特征提取。
  3. 全连接层:将LSTM的输出特征映射到分类结果。

为什么选择LSTM而不是经典RNN? LSTM通过引入门控机制,有效地缓解了经典RNN中存在的梯度消失问题,能够更好地捕捉长距离依赖信息,提高模型的效果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers, bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

损失函数与优化器

我们使用二分类交叉熵损失函数nn.BCEWithLogitsLoss和Adam优化器。

为什么使用二分类交叉熵损失函数和Adam优化器? 二分类交叉熵损失函数适用于二分类任务,能够衡量模型预测结果与真实标签之间的差异。Adam优化器结合了动量和自适应学习率的优点,具有较快的收敛速度和较好的效果,广泛应用于深度学习模型的训练。

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

训练逻辑

设计训练一个epoch的函数,用于训练过程和loss的可视化。

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

评估指标和逻辑

设计评估逻辑,计算模型在验证集上的准确率。

def binary_accuracy(preds, y):rounded_preds = np.around(ops.sigmoid(preds).asnumpy())correct = (rounded_preds == y).astype(np.float32)acc = correct.sum() / len(correct)return accdef evaluate(model, test_dataset, criterion, epoch=0):total = test_dataset.get_dataset_size()epoch_loss = 0epoch_acc = 0step_total = 0model.set_train(False)with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in test_dataset.create_tuple_iterator():predictions = model(i[0])loss = criterion(predictions, i[1])epoch_loss += loss.asnumpy()acc = binary_accuracy(predictions, i[1])epoch_acc += accstep_total += 1t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)t.update(1)return epoch_loss / total

模型训练与保存

进行模型训练,并保存最优模型。

num_epochs = 2
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')for epoch in range(num_epochs):train_one_epoch(model, imdb_train, epoch)valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)if valid_loss < best_valid_loss:best_valid_loss = valid_lossms.save_checkpoint(model, ckpt_file_name)

模型加载与测试

加载已保存的最优模型,并在测试集上进行评估。

param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)
imdb_test = imdb_test.batch(64)
evaluate(model, imdb_test, loss_fn)

自定义输入测试

设计一个预测函数,实现对自定义输入的情感预测。

score_map = {1: "Positive",0: "Negative"
}def predict_sentiment(model, vocab, sentence):model.set_train(False)tokenized = sentence.lower().split()indexed = vocab.tokens_to_ids(tokenized)tensor = ms.Tensor(indexed, ms.int32)tensor = tensor.expand_dims(0)prediction = model(tensor)

通过本文的学习,我们成功地使用MindSpore框架实现了一个基于RNN的情感分类模型。我们从数据准备开始,详细讲解了如何加载和处理IMDB影评数据集,以及使用预训练的Glove词向量对文本进行编码。然后,我们构建了一个包含Embedding层、LSTM层和全连接层的情感分类模型,并使用二分类交叉熵损失函数和Adam优化器进行训练。最后,我们评估了模型在测试集上的性能,并实现了对自定义输入的情感预测。希望这篇博客能帮助你更好地理解RNN在自然语言处理中的应用,并激发你在NLP领域的更多探索和实践。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473884.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

UE5 07-给物体添加一个拖尾粒子

添加一个(旧版粒子系统)cascade粒子系统组件 ,在模板中选择一个开发学习初始包里的粒子

智慧文旅(景区)解决方案PPT(42页)

智慧文旅解决方案摘要 行业分析中国旅游业正经历消费大众化、需求品质化、发展全域化和产业现代化的发展趋势。《“十三五”旅游业发展规划》的发布&#xff0c;以及文化和旅游部的设立&#xff0c;标志着旅游业的信息化和智能化建设成为国家战略。2018年推出的旅游行业安全防范…

cs224n作业4

NMT结构图&#xff1a;&#xff08;具体结构图&#xff09; LSTM基础知识 nmt_model.py&#xff1a; 参考文章&#xff1a;LSTM输出结构描述 #!/usr/bin/env python3 # -*- coding: utf-8 -*-""" CS224N 2020-21: Homework 4 nmt_model.py: NMT Model Penchen…

2024年导游资格证题库备考题库,高效备考!

1.台湾著名的太鲁阁公园的特色是&#xff08;&#xff09;。 A.丘陵和溶洞 B.森林和瀑布 C.峡谷和断崖 D.彩林和彩池 答案&#xff1a;C 解析&#xff1a;台湾著名的太鲁阁公园的特色是峡谷和断崖。 2.下列位于台湾的景区中&#xff0c;素有"神秘的森林王国"之…

51单片机STC89C52RC——15.1 AD/DA(模数数模)

目的/效果 1 LCD1602 显示 可调电阻、光敏电阻、热敏电阻值&#xff08;AD&#xff09; 2 模拟信号控制LED明暗&#xff08;DA&#xff09; 一&#xff0c;STC单片机模块 二&#xff0c;AD/DA 2.1 AD/DA 介绍 AD&#xff08;Analog to Digital&#xff09;&#xff1a;模拟…

金丝键合强度测试仪试验条件要求:键合拉脱/引线拉力/剪切力等

金丝键合强度测试仪是测量引线键合强度&#xff0c;评估键合强度分布或测定键合强度是否符合有关的订购文件的要求。键合强度试验机可应用于采用低温焊、热压焊、超声焊或有关技术键合的、具有内引线的器件封装内部的引线-芯片键合、引线-基板键合或内引线一封装引线键合&#…

Redis的zset的zrem命令可以做到O(1)吗?

事情是这样的&#xff0c;当我用zrem命令去移除value的时候&#xff0c;我知道他之前会做的几个步骤 1、查找这个value对应的score&#xff08;通过zset中的dict&#xff09;2、根据这个score查找到跳表中的节点3、删除这个节点 我就想了一下为什么dict为什么要保存score呢&a…

非堆成加密是公私钥使用

对称加密学习-CSDN博客 加密算法学习-CSDN博客 非对称加密算法使用一对密钥&#xff0c;包括一个公钥和一个私钥&#xff0c;它们是数学上相关联的&#xff0c;但公钥可以公开分享&#xff0c;而私钥必须保密。以下是使用非对称加密算法的一般步骤&#xff1a; 密钥生成&…

2024年江苏省研究生数学建模竞赛B题火箭烟幕弹运用策略优化论文和代码分析

经过不懈的努力&#xff0c; 2024年江苏省研究生数学建模竞赛B题火箭烟幕弹运用策略优化论文和代码已完成&#xff0c;代码为B题全部问题的代码&#xff0c;论文包括摘要、问题重述、问题分析、模型假设、符号说明、模型的建立和求解&#xff08;问题1模型的建立和求解、问题2模…

基于Java中的SSM框架实现计算机类考研院校推荐系统项目【项目源码+论文说明】

基于Java中的SSM框架实现计算机类考研院校推荐系统演示 摘要 在互联网时代人们获取信息的方式变得非常快捷&#xff0c;登录网站搜索就能快速查找到相关的信息&#xff0c;但是网络上面的信息数量非常庞大&#xff0c;有很多信息虽然和自己搜索的相关&#xff0c;但并不是自己…

猫狗图像分类-划分数据集

&#x1f4da;博客主页&#xff1a;knighthood2001 ✨公众号&#xff1a;认知up吧 &#xff08;目前正在带领大家一起提升认知&#xff0c;感兴趣可以来围观一下&#xff09; &#x1f383;知识星球&#xff1a;【认知up吧|成长|副业】介绍 ❤️如遇文章付费&#xff0c;可先看…

网络-calico问题分析

项目场景&#xff1a; calico-node日志提示 Failed to auto-detect host MTU - no interfaces matched the MTU interface pattern. To use auto-MTU, set mtuifacePattern to match your hosts’s interfaes. 同时&#xff0c;cali开头网卡的mtu是1440大小 原因分析&#xff…

强技能 展风采 促提升——北京市大兴区餐饮行业职工技能竞赛精彩呈现

6月19日&#xff0c;由大兴区总工会、区商务局、青云店镇人民政府联合主办&#xff0c;区服务工会、区餐饮行业协会承办的“传承中国技艺&#xff0c;打造新一代餐饮工匠”2024年大兴区餐饮行业职工职业技能竞赛决赛在北京华联创新学习中心隆重开幕。区总工会副主席郝泽宏&…

Alibaba Cloud Toolkit前端使用proxy代理配置

1、vscode 先安装插件 Alibaba Cloud Toolkit 2、前端代码&#xff1a; /personnel: {// target: http://xxx.xx.xxx.xx:9100, // 测试环境// target: http://xxx.xx.xxx.xx:9200, // 线上环境target: http://127.0.0.1:18002, // toolkit 代理changeOrigin: true,},3、打开插…

【Pyhton】读取寄存器数据到MySQL数据库

目录 步骤 modsim32软件配置 Navicat for MySQL 代码实现 步骤 安装必要的库&#xff1a;确保安装了pymodbus和pymysql。 配置Modbus连接&#xff1a;设置Modbus从站的IP地址、端口&#xff08;对于TCP&#xff09;或串行通信参数&#xff08;对于RTU&#xff09;。 连接M…

第三方商城对接重构(HF202407)

文章目录 项目背景一、模块范围二、问题方案1. 商品模块整体来说这块对接的不是太顺利,梳理了几条大概的思路:2. 订单模块3. 售后4. 发票5. 结算单经验总结项目背景 作为供应商入围第三方商城成功,然后运营了一段时间,第三方通知要重构, 需要重新对接打通接口完成系统对接…

gcc/g++的四步编译

目录 前言1.预处理&#xff08;进行宏替换&#xff09;2.编译&#xff08;生成汇编&#xff09;3.汇编&#xff08;生成二进制文件&#xff09;4. 链接 &#xff08;生成可执行文件&#xff09;a. 动态库 && 动态链接b. 静态库 && 静态链接c. 验证d. 动静态链接…

SCI论文发表:构建清晰论文框架的10个原则 (附思维导图,建议收藏)

我是娜姐 迪娜学姐 &#xff0c;一个SCI医学期刊编辑&#xff0c;探索用AI工具提效论文写作和发表。 论文框架是什么&#xff1f;对我们完成一篇论文有哪些作用&#xff1f; 之前娜姐分享过一篇深圳湾实验室周耀旗教授关于论文写作的文章&#xff0c;他提出的第一个重要原则就…

VScode将界面语言设置为中文

1. 点击左侧的扩展图标&#xff0c;打开侧边栏“EXTENSIONS”面板。 2. 在搜索框中输入“Chinese”&#xff0c;查找出“中文简体”插件&#xff0c;点击“install”按钮。 3. 等待插件安装完成&#xff0c;点击右下角“restart”按钮&#xff0c;从而重新启动Vscode。

Linux多进程和多线程(七)进程间通信-信号量

进程间通信之信号量 资源竞争 多个进程竞争同一资源时&#xff0c;会发生资源竞争。 资源竞争会导致进程的执行出现不可预测的结果。 临界资源 不允许同时有多个进程访问的资源, 包括硬件资源 (CPU、内存、存储器以及其他外 围设备) 与软件资源(共享代码段、共享数据结构) …