HuggingFace情感分析任务微调

官方教程地址:https://huggingface.co/learn/nlp-course/zh-CN/chapter3/1?fw=pt

部分内容参考:

李福林, & 计算机技术. (2023). HuggingFace 自然语言处理详解: 基于 BERT 中文模型的任务实战. 清华大学出版社.

HuggingFace将AI项目研发分为四个步骤,准备数据集、定义模型、训练、测试,在此大方向下可以细化成几个小步骤,其中HuggingFace对此提供了一些工具集,具体如图:

HuggingFace开发流程和提供的工具集
因为一些原因可能无法连接到huggingface的服务器,所以可以在代码片中加入这一段来连接国内镜像

import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

1. 准备数据集

1.1 加载编码工具

这里我们选择微调的是IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment模型,这个是一个情感分析模型,本身已经做的特别好了,其实微调意义已经不大甚至可能适得其反,但是这里仅为记录微调过程而并非真正需要优化模型。由于模型与编码器经常是成对出现,所以这里加载编码器也是选择IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment")

这里可以测试一下:

tokenizer(['房东已经不养猫了', '今天真的要减肥了'],truncation=True,max_length=512,
)

输出如下:

{'input_ids': [[101, 2791, 691, 2347, 5307, 679, 1075, 4344, 749, 102], [101, 791, 1921, 4696, 4638, 6206, 1121, 5503, 749, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

解释一下这个编码结果:

  1. input_ids
  • 将文本转换为数字序列的token ID
  • 每个ID对应词表中的一个token(词或子词)
  1. token_type_ids (也叫segment_ids)
  • 用于区分输入中不同的句子或文本段
  • 通常用0和1标记,0表示第一个句子,1表示第二个句子
  • 在单句任务中全部为0
  1. attention_mask
  • 用于标记哪些token应该被注意(1),哪些应该被忽略(0)
  • 主要用于处理变长序列的padding情况
  • 实际token为1,padding token为0

1.2 加载数据集

# from datasets import load_from_disk
from datasets import load_dataset# dataset = load_from_disk("/kaggle/working/Huggingface_Toturials/data/ChnSentiCorp")
dataset = load_dataset('lansinuote/ChnSentiCorp')
dataset['train'] = dataset['train'].shuffle().select(range(2000))
dataset['test'] = dataset['test'].shuffle().select(range(100))

加载数据集部分可以直接从网站一键下载,也可以手动下载了从磁盘载入,这里使用的是ChnSentiCorp,这里为简化运算,只取随机2000行,可以打印一下dataset结果如下

DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 2000})validation: Dataset({features: ['text', 'label'],num_rows: 1200})test: Dataset({features: ['text', 'label'],num_rows: 100})
})

1.3 数据集预处理

用刚刚加载进来的编码器编码

def f(data):return tokenizer.batch_encode_plus(data['text'], truncation=True, max_length=512)dataset = dataset.map(f, batched=True, remove_columns=['text'], batch_size=1000, num_proc=3)

打印dataset可以看到:

DatasetDict({train: Dataset({features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 2000})validation: Dataset({features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 1200})test: Dataset({features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 100})
})

模型一般对句子的长度有所限制,因此将长度超过512的句子截断或者过滤,这里为了编码方便简单的选择了删掉长度不合格的句子。

def f(data):return [len(i) <= 512 for i in data['input_ids']]
dataset = dataset.filter(f, batched=True, num_proc=3, batch_size=1000)

2. 定义模型和训练工具

2.1 加载预训练模型

先将模型加载进来

from transformers import AutoModelForSequenceClassification
import torchmodel = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment", num_labels=2)

简单计算下参数量

sum(p.numel() for p in model.parameters()) / 1e6

结果:

325.524482

参数量大概是325.5M

模型加载进来后进行简单的试算

data = {'input_ids': torch.ones(1, 10, dtype=torch.long),'attention_mask': torch.ones(1, 10, dtype=torch.long),'token_type_ids': torch.ones(1, 10, dtype=torch.long),'labels': torch.ones(1, dtype=torch.long)
}
out = model(**data)
out.loss, out.logits.shape

2.2 加载评价函数

import evaluate
metric = evaluate.load("accuracy")

这个评价函数接受的主要参数是一个预测值和一个标签值,与模型的输出不符,因此我们需要做一些处理。

import numpy as npdef compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=1)acc = metric.compute(predictions=predictions, references=labels)return acc

测试一下这个函数

from transformers.trainer_utils import EvalPredictioneval_pred = EvalPrediction(predictions=np.array([[0, 1], [2, 3], [4, 5], [6, 7]]),label_ids=np.array([1, 1, 0, 1]))
compute_metrics(eval_pred) 

结果:

{'accuracy': 0.75}

2.3 定义训练函数

定义训练参数

from transformers import Trainer, TrainingArguments
import accelerate
# 参数
training_args = TrainingArguments(output_dir="./output_dir",evaluation_strategy="steps",learning_rate=2e-5,per_device_train_batch_size=4,per_device_eval_batch_size=4,num_train_epochs=2,weight_decay=0.01,eval_steps=20,no_cuda=False,report_to='none',
)

注意:report_to='none',在用colab或者kaggle时注意要加上,不然会让你输入api key,比较麻烦

构建训练器

from transformers import Trainer
from transformers import DataCollatorWithPaddingtrainer = Trainer(model=model,args=training_args,train_dataset=dataset['train'],eval_dataset=dataset['test'],data_collator=DataCollatorWithPadding(tokenizer),compute_metrics=compute_metrics,
)

上面的训练器中出现了一个常用的DataCollatorWithPadding对象,它的主要功能是将不同长度的序列补齐到同一长度,自动处理padding,使得一个batch内的所有样本长度一致。这里可以测试一下

# 测试数据整理函数
data_collator = DataCollatorWithPadding(tokenizer)
data = dataset['train'][:5]
for i in data['input_ids']:print(len(i))
data = data_collator(data)
for k, v in data.items():print(k, v.shape)

结果:

103
162
171
51
95
input_ids torch.Size([5, 171])
token_type_ids torch.Size([5, 171])
attention_mask torch.Size([5, 171])
labels torch.Size([5])

长度全部都补齐到171了

可以解码看看

tokenizer.decode(data['input_ids'][0])  # 解码

结果:

'[CLS] 看 了 两 边 , 第 一 感 觉 是 - - 很 一 般 。 内 容 上 , 真 不 敢 苟 同 , 还 号 称 是 学 术 明 星 , 于 丹 的 同 志 的 见 解 真 让 我 张 了 见 识 ! 页 数 大 可 以 压 缩 到 50 页 , 何 必 浪 费 纸 张 呢 ? 难 道 孔 夫 子 没 教 育 你 怎 么 搞 环 保 ? 评 论 到 此 结 束 , 懒 得 浪 费 我 得 笔 墨 ! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

3. 训练和测试函数

在训练前,先看看模型本身的能力

trainer.evaluate()  # 评估

结果

{'eval_loss': 0.2859102189540863,'eval_accuracy': 0.97,'eval_runtime': 5.3759,'eval_samples_per_second': 18.602,'eval_steps_per_second': 2.418}

模型本身的准确率就有0.97,已经非常优秀了,本文目的不在于优化模型。

训练模型

trainer.train()

结果

 [500/500 09:48, Epoch 2/2]
Step	Training Loss	Validation Loss	Accuracy
20	No log	0.628869	0.940000
40	No log	0.231194	0.980000
60	No log	0.496170	0.930000
80	No log	0.381901	0.950000
100	No log	0.326569	0.940000
120	No log	0.262761	0.950000
140	No log	0.305643	0.960000
160	No log	0.266394	0.960000
180	No log	0.251125	0.960000
200	No log	0.268621	0.950000
220	No log	0.188149	0.980000
240	No log	0.365949	0.950000
260	No log	0.420138	0.940000
280	No log	0.337165	0.940000
300	No log	0.343916	0.950000
320	No log	0.427644	0.950000
340	No log	0.543159	0.930000
360	No log	0.514463	0.930000
380	No log	0.450759	0.940000
400	No log	0.422249	0.940000
420	No log	0.437221	0.940000
440	No log	0.448282	0.950000
460	No log	0.447407	0.950000
480	No log	0.447090	0.950000
500	0.068800	0.447007	0.950000

测试

trainer.evaluate()

4. 一键复制的python代码

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment")
from datasets import load_from_disk
from datasets import load_dataset
import torch
from transformers import AutoModelForSequenceClassification
import evaluate
import numpy as np
from transformers.trainer_utils import EvalPredictiondataset = load_dataset('lansinuote/ChnSentiCorp')
dataset['train'] = dataset['train'].shuffle().select(range(2000))
dataset['test'] = dataset['test'].shuffle().select(range(100))def f(data):return tokenizer.batch_encode_plus(data['text'], truncation=True, max_length=512)dataset = dataset.map(f, batched=True, remove_columns=['text'], batch_size=1000, num_proc=3)def f(data):return [len(i) <= 512 for i in data['input_ids']]
dataset = dataset.filter(f, batched=True, num_proc=3, batch_size=1000)model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Sentiment", num_labels=2)metric = evaluate.load("accuracy")def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=1)acc = metric.compute(predictions=predictions, references=labels)return acc# 定义训练函数
from transformers import Trainer, TrainingArguments
import accelerate
# 参数
training_args = TrainingArguments(output_dir="./output_dir",evaluation_strategy="steps",learning_rate=2e-5,per_device_train_batch_size=4,per_device_eval_batch_size=4,num_train_epochs=2,weight_decay=0.01,eval_steps=20,no_cuda=False,report_to='none',
)
# 训练器
from transformers import Trainer
from transformers import DataCollatorWithPaddingtrainer = Trainer(model=model,args=training_args,train_dataset=dataset['train'],eval_dataset=dataset['test'],data_collator=DataCollatorWithPadding(tokenizer),compute_metrics=compute_metrics,
)
trainer.train()
trainer.evaluate()  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/3924.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Springboot——对接支付宝实现扫码支付

文章目录 前言官方文档以及说明1、申请沙箱2、进入沙箱获取对应的关键信息3、拿到系统生成的公钥和密钥 注意事项创建springboot项目1、引入依赖2、配置连接参数3、创建配置类&#xff0c;用于接收这些参数4、中间类的定义(订单类)5、编写测试接口场景一、pc端请求后端后&#…

迪杰斯特拉算法

迪杰斯特拉算法 LeetCode 743. 网络延迟时间 https://blog.csdn.net/xiaoxi_hahaha/article/details/110257368 import sysdef dijkstra(graph, source):"""dijkstra算法:param graph: 邻接矩阵:param source: 出发点&#xff0c;源点:return:""&…

STL学习-容器适配器

一.stack栈 1.栈的介绍 stack 栈是一种只在一端(栈顶)进行数据插入(入栈)和删除(出栈)的数据结构,它满足后进 先出(LIFO)的特性。 使用push(入栈)将数据放入stack,使用pop(出栈)将元素从容器中移除。 栈的结构如图&#xff1a; 在头文件<stack>中&#xff0c;class st…

【C语言】动态内存开辟

写在前面 C语言中有不少开辟空间的办法&#xff0c;但是在堆上开辟的方法也就只有动态内存开辟&#xff0c;其访问特性与数组相似&#xff0c;但最大区别是数组是开辟在栈上&#xff0c;而动态内存开辟是开辟在堆上的。这篇笔记就让不才娓娓道来。 PS:本篇没有目录实在抱歉CSD…

海的记忆:海滨学院班级回忆录项目

4系统概要设计 4.1概述 本系统采用B/S结构(Browser/Server,浏览器/服务器结构)和基于Web服务两种模式&#xff0c;是一个适用于Internet环境下的模型结构。只要用户能连上Internet,便可以在任何时间、任何地点使用。系统工作原理图如图4-1所示&#xff1a; 图4-1系统工作原理…

【VScode】C/C++多文件夹下、多文件引用、分别编译——仅一个设置【适合新人入手】

【VScode】C/C多文件夹内的多文件引用编译 1、问题2、前提&#xff08;最简环境&#xff09;3、核心&#xff08;关键配置&#xff09;4、成功享用~ 1、问题 在使用 VScode 编写一个简单项目的时候&#xff0c;没有特别配置的情况下&#xff0c;若主文件(.c)引用了自定义的头文…

62 mysql 中 存储引擎MyISAM 中索引的使用

前言 固定数据表 mysql. tables_priv 的表结构创建如下 CREATE TABLE tables_priv (Host char(60) COLLATE utf8_bin NOT NULL DEFAULT ,Db char(64) COLLATE utf8_bin NOT NULL DEFAULT ,User char(32) COLLATE utf8_bin NOT NULL DEFAULT ,Table_name char(64) COLLATE u…

使用buildx构建多架构平台镜像

1. 查看buildx插件信息 比较新的docker-ce版本默认已经集成了buildx插件 [rootdocker ~]# docker buildx version github.com/docker/buildx v0.11.2 9872040 [rootdocker ~]#2. 增加多平台镜像构建支持 通过tonistiigi/binfmt:latest初始化一个基于容器的构建环境&#xff…

数据库基础(3) . Navicat使用

0.下载安装 官网 : https://www.navicat.com.cn/ Navicat 中国 | 支持 MySQL、Redis、MariaDB、MongoDB、SQL Server、SQLite、Oracle 和 PostgreSQL 的数据库管理 1.连接数据库 1.1.连接 1.1.1.点击连接 打开navicat 点击 左上角连接 1.1.2.选择MySQL 弹出配置界面 1.1…

MySQL(上)

一、SQL优化 1、如何定位及优化SQL语句的性能问题&#xff1f;创建的索引有没有被使用到?或者说怎么才可以知道这条语句运行很慢的原因&#xff1f; 对于性能比较低的sql语句定位&#xff0c;最重要的也是最有效的方法其实还是看sql的执行计划&#xff0c;而对于mysql来说&a…

国密SM2 非对称加解密前后端工具

1.依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.21</version></dependency><dependency><groupId>org.bouncycastle</groupId><artifactId>bcpki…

【银河麒麟操作系统】软raid重建速度限制问题分析

了解更多银河麒麟操作系统全新产品&#xff0c;请点击访问 麒麟软件产品专区&#xff1a;https://product.kylinos.cn 开发者专区&#xff1a;https://developer.kylinos.cn 文档中心&#xff1a;https://documentkylinos.cn 现象描述 遇到软raid重建速度问题&#xff0c;分…

ssm教室信息管理系统+vue

系统包含&#xff1a;源码论文 所用技术&#xff1a;SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习&#xff0c;获取源码看文章最下面 需要定制看文章最下面 目 录 目 录 III 1 绪论 1 1.1 研究背景 1 1.2目的和意义 1 1.3 论文结构安排 2 2 相关技术 3 …

去中心化存储:Web3中的数据安全新标准

随着Web3的兴起&#xff0c;去中心化存储逐渐成为数据安全的新标准。传统的中心化存储方式将数据集中保存在少数服务器上&#xff0c;这种模式尽管在早期互联网中被广泛应用&#xff0c;但随着数据量和数据价值的增加&#xff0c;其潜在的安全风险和隐私问题也逐渐暴露。而去中…

Ubuntu 22 安装 Apache Doris 3.0.3 笔记

Ubuntu 22 安装 Apache Doris 3.0.3 笔记 1. 环境准备 Doris 需要 Java 17 作为运行环境&#xff0c;所以首先需要安装 Java 17。 sudo apt-get install openjdk-17-jdk -y sudo update-alternatives --config java在安装 Java 17 后&#xff0c;可以通过 sudo update-alter…

安卓摄像头的详细使用

安卓摄像头的详细使用 一、引言二、权限设置三、打开摄像头四、摄像头的属性设置&#xff08;一&#xff09;预览尺寸&#xff08;二&#xff09;图片格式&#xff08;三&#xff09;对焦模式 五、摄像头预览六、拍照功能七、视频录制 一、引言 在安卓开发中&#xff0c;摄像头…

服务器的配置复杂,租用时该如何选择参数?

对于互联网企业来说&#xff0c;开发一套可以接入互联网的产品&#xff0c;并利用它来盈利是终极目的。但互联网产品必须有服务器才能运行&#xff0c;对于很多公司来说&#xff0c;托管服务器成本太高&#xff0c;而租用服务器才算得上是最好的选择&#xff0c;但面对配置参数…

10min本地安装Qwen1.5-0.5B-Chat

大模型系列文章 本地电脑离线部署大模型 配置&#xff1a;MAC-M1-8GB 10min本地安装Qwen1.5-0.5B-Chat 大模型系列文章前言一、下载Qwen1.5-0.5B-Chat二、构造函数chatBot.py三、启动命令1、放置脚本2、启动命令3、效果图 前言 在人工智能领域&#xff0c;大模型无疑是最炙手…

90%会展主办方都会用的6款数字化工具

在会展行业&#xff0c;数字化转型已成为提升竞争力的关键。面对日益增长的运营成本和收入增长的瓶颈&#xff0c;主办方需要借助数字化工具来实现效率提升和成本控制。 今天介绍几种常见的数字化工具和应用方式。 一、线上展览平台 构建线上展览平台是会展主办方拓展线上销…

弃用 RestTemplate,来了解一下官方推荐的 WebClient !

在 Spring Framework 5.0 及更高版本中&#xff0c;RestTemplate 已被弃用&#xff0c;取而代之的是较新的 WebClient。这意味着虽然 RestTemplate 仍然可用&#xff0c;但鼓励 Spring 开发人员迁移到新项目的 WebClient。 WebClient 优于 RestTemplate 的原因有几个&#xff…