聊聊基于BERT模型实现多标签分类任务的实践与思考

概述

以预训练大模型为基座神经网络模型,通过模型预训练后的泛化能力与微调后的领域能力,作为NLP任务的解决方案。

在github上找了一个简单的仓库——multi_label_classification,该仓库基于BERT预训练大模型实现了多分类任务。通过对该仓库源码的分析,深入研究其逻辑原理。

代码文件的注释如下:
在这里插入图片描述

模型类定义

基于BERT预训练模型,结合线性层与sigmoid函数,输出多分类任务的概率分布

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel# 基于BERT模型的多分类模型定义
class BertMultiLabelCls(nn.Module):def __init__(self, hidden_size, class_num, dropout=0.1):super(BertMultiLabelCls, self).__init__()# 定义线性层(密集层)# 输入的维度是hidden_size;输出的维度是多分类的个数。self.fc = nn.Linear(hidden_size, class_num)self.drop = nn.Dropout(dropout)# 加载预训练模型self.bert = BertModel.from_pretrained("bert-base-chinese")def forward(self, input_ids, attention_mask, token_type_ids):# 将输入的input_ids(文本token的ID),# attention_mask(表示哪些token是重要的)# token_type_ids(区分不同类型的token,如句子A和句子B)传递给BERT模型,获取模型的输出。outputs = self.bert(input_ids, attention_mask, token_type_ids)# 打印输出# print(outputs)# 从BERT模型的输出中选择第一个元素(通常是[CLS]标记的输出),然后通过dropout层。cls = self.drop(outputs[1])# 将经过dropout层的[CLS]标记的输出传递到全连接层self.fc,然后应用sigmoid激活函数,将输出转换为概率分布。out = F.sigmoid(self.fc(cls))# 返回最终的输出概率分布,用于多分类任务。return out

Sigmoid函数通常用于二分类问题的神经网络中;Softmax函数通常用于多分类问题的神经网络中。而这里用到的是Sigmoid函数,这个原因在后面也会讲到。

数据预处理

一般都会有一个数据预处理的流程,因为模型的数据需要处理成指定的JSON格式后,才能喂到模型训练。为什么要有对应的格式?因为模型的训练与微调有对应的格式要求。

譬如,alpaca格式数据集与sharegpt格式数据集等常见格式要求。因此就会要求将原始数据处理成对应格式。

# -*- coding: utf-8 -*-
import json
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from data_preprocess import load_jsonclass MultiClsDataSet(Dataset):def __init__(self, data_path, max_len=128, label2idx_path="./data/label2idx.json"):self.label2idx = load_json(label2idx_path)self.class_num = len(self.label2idx)self.tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")self.max_len = max_lenself.input_ids, self.token_type_ids, self.attention_mask, self.labels = self.encoder(data_path)def encoder(self, data_path):texts = []labels = []with open(data_path, encoding="utf-8") as f:for line in f:line = json.loads(line)texts.append(line["text"])tmp_label = [0] * self.class_numfor label in line["label"]:tmp_label[self.label2idx[label]] = 1labels.append(tmp_label)tokenizers = self.tokenizer(texts,padding=True,truncation=True,max_length=self.max_len,return_tensors="pt",is_split_into_words=False)input_ids = tokenizers["input_ids"]token_type_ids = tokenizers["token_type_ids"]attention_mask = tokenizers["attention_mask"]return input_ids, token_type_ids, attention_mask, \torch.tensor(labels, dtype=torch.float)def __len__(self):return len(self.labels)def __getitem__(self, item):return self.input_ids[item],  self.attention_mask[item], \self.token_type_ids[item], self.labels[item]if __name__ == '__main__':dataset = MultiClsDataSet(data_path="./data/train.json")print(dataset.input_ids)print(dataset.token_type_ids)print(dataset.attention_mask)print(dataset.labels)
# -*- coding: utf-8 -*-"""
数据预处理
"""import jsondef load_json(data_path):with open(data_path, encoding="utf-8") as f:return json.loads(f.read())def dump_json(project, out_path):with open(out_path, "w", encoding="utf-8") as f:json.dump(project, f, ensure_ascii=False)def preprocess(train_data_path, label2idx_path, max_len_ratio=0.9):""":param train_data_path::param label2idx_path::param max_len_ratio::return:"""labels = []text_length = []with open(train_data_path, encoding="utf-8") as f:for data in f:data = json.loads(data)text_length.append(len(data["text"]))labels.extend(data["label"])labels = list(set(labels))label2idx = {label: idx for idx, label in enumerate(labels)}dump_json(label2idx, label2idx_path)text_length.sort()print("当设置max_len={}时,可覆盖{}的文本".format(text_length[int(len(text_length)*max_len_ratio)], max_len_ratio))if __name__ == '__main__':preprocess("./data/train.json", "./data/label2idx.json")

训练

训练的源码如下:

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AdamW
import numpy as np
from data_preprocess import load_json
from bert_multilabel_cls import BertMultiLabelCls
from data_helper import MultiClsDataSet
from sklearn.metrics import accuracy_scoretrain_path = "./data/train.json"
dev_path = "./data/dev.json"
test_path = "./data/test.json"
label2idx_path = "./data/label2idx.json"
save_model_path = "./model/multi_label_cls.pth"
label2idx = load_json(label2idx_path)
class_num = len(label2idx)
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 2e-5
batch_size = 128
max_len = 128
hidden_size = 768
epochs = 10# 预处理数据
train_dataset = MultiClsDataSet(train_path, max_len=max_len, label2idx_path=label2idx_path)
dev_dataset = MultiClsDataSet(dev_path, max_len=max_len, label2idx_path=label2idx_path)# 从数据集中 批量 加载数据,批大小为batch_size
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)# 计算准确率
def get_acc_score(y_true_tensor, y_pred_tensor):y_pred_tensor = (y_pred_tensor.cpu() > 0.5).int().numpy()y_true_tensor = y_true_tensor.cpu().numpy()return accuracy_score(y_true_tensor, y_pred_tensor)# 训练
def train():model = BertMultiLabelCls(hidden_size=hidden_size, class_num=class_num)# 启用 batch normalization 和 dropout 。model.train()model.to(device)# 定义优化器optimizer = AdamW(model.parameters(), lr=lr)# 定义了一个二进制交叉熵损失函数(BCELoss),用于多标签分类问题,因为它可以处理多个标签。criterion = nn.BCELoss()dev_best_acc = 0.# 按epoch训练,即训练轮数for epoch in range(1, epochs):# 启用 batch normalization 和 dropout 。model.train()# 按batch训练,即训练批次for i, batch in enumerate(train_dataloader):# 清空梯度optimizer.zero_grad()batch = [d.to(device) for d in batch]# 获取批数据中的标签label数据labels = batch[-1]# 执行预训练模型的forward方法logits = model(*batch[:3])# 通过二分类交叉熵损失,计算模型返回值与标签实际值的损失概率loss = criterion(logits, labels)# 反向传播loss.backward()# 梯度更新optimizer.step()# 打印数据if i % 100 == 0:acc_score = get_acc_score(labels, logits)print("Train epoch:{} step:{}  acc: {} loss:{} ".format(epoch, i, acc_score, loss.item()))# 验证集合dev_loss, dev_acc = dev(model, dev_dataloader, criterion)print("Dev epoch:{} acc:{} loss:{}".format(epoch, dev_acc, dev_loss))if dev_acc > dev_best_acc:dev_best_acc = dev_acctorch.save(model.state_dict(), save_model_path)# 测试test_acc = test(save_model_path, test_path)print("Test acc: {}".format(test_acc))# 验证
def dev(model, dataloader, criterion):all_loss = []# 切换成评估模式model.eval()true_labels = []pred_labels = []with torch.no_grad():for i, batch in enumerate(dataloader):input_ids, attention_mask, token_type_ids, labels = [d.to(device) for d in batch]logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)loss = criterion(logits, labels)all_loss.append(loss.item())true_labels.append(labels)pred_labels.append(logits)true_labels = torch.cat(true_labels, dim=0)pred_labels = torch.cat(pred_labels, dim=0)acc_score = get_acc_score(true_labels, pred_labels)return np.mean(all_loss), acc_score# 测试
def test(model_path, test_data_path):test_dataset = MultiClsDataSet(test_data_path, max_len=max_len, label2idx_path=label2idx_path)test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)model = BertMultiLabelCls(hidden_size=hidden_size, class_num=class_num)model.load_state_dict(torch.load(model_path))model.to(device)# 切换成评估模式model.eval()true_labels = []pred_labels = []with torch.no_grad():for i, batch in enumerate(test_dataloader):input_ids, attention_mask, token_type_ids, labels = [d.to(device) for d in batch]logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)true_labels.append(labels)pred_labels.append(logits)true_labels = torch.cat(true_labels, dim=0)pred_labels = torch.cat(pred_labels, dim=0)acc_score = get_acc_score(true_labels, pred_labels)return acc_scoreif __name__ == '__main__':train()

在上面的代码中,可以看到很多熟悉的参数。譬如epochsbatch_sizemax_lenhidden_size——这些参数一般在配置文件config中毕竟常见,而且都是有默认值。以chatglm3-6b的配置文件为例:

{"_name_or_path": "THUDM/chatglm-6b","architectures": ["ChatGLMModel"],"bos_token_id": 130004,"eos_token_id": 130005,"mask_token_id": 130000,"gmask_token_id": 130001,"pad_token_id": 3,"hidden_size": 4096,"inner_hidden_size": 16384,"layernorm_epsilon": 1e-05,"max_sequence_length": 2048,"model_type": "chatglm","num_attention_heads": 32,"num_layers": 28,"position_encoding_2d": true,"torch_dtype": "float16","use_cache": true,"vocab_size": 130528
}

在每个batch的训练中,其流程是这样的:

  1. 清空梯度
  2. 获取实际的标签
  3. 通过预训练模型输出预测值
  4. 基于二分类交叉熵损失函数,计算实际值与预测值的损失
  5. 将损失反向传播
  6. 更新梯度

严格来说,多分类任务应该用对应的损失函数,譬如CrossEntropyLoss或者NLLLoss( Negative Log Likelihood Loss)。这两个损失函数都是为多分类问题设计的,CrossEntropyLoss是更为常用的选择。

不过这里用二分类交叉熵损失函数也可以,即将多分类问题转化为多个二分类问题。所以在模型的定义中,使用了Sigmode函数。

预测

推理预测的代码如下:

# -*- coding: utf-8 -*-import torch
from data_preprocess import load_json
from bert_multilabel_cls import BertMultiLabelCls
from transformers import BertTokenizerhidden_size = 768
class_num = 3
label2idx_path = "./data/label2idx.json"
save_model_path = "./model/multi_label_cls.pth"
label2idx = load_json(label2idx_path)
idx2label = {idx: label for label, idx in label2idx.items()}
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
max_len = 128model = BertMultiLabelCls(hidden_size=hidden_size, class_num=class_num)
model.load_state_dict(torch.load(save_model_path))
model.to(device)
# 切换评估模式
model.eval()def predict(texts):# 加载分词器分词outputs = tokenizer(texts, return_tensors="pt", max_length=max_len,padding=True, truncation=True)# 加载模型logits = model(outputs["input_ids"].to(device),outputs["attention_mask"].to(device),outputs["token_type_ids"].to(device))logits = logits.cpu().tolist()# print(logits)result = []for sample in logits:pred_label = []for idx, logit in enumerate(sample):if logit > 0.5:pred_label.append(idx2label[idx])result.append(pred_label)return resultif __name__ == '__main__':texts = ["中超-德尔加多扳平郭田雨绝杀 泰山2-1逆转亚泰", "今日沪深两市指数整体呈现震荡调整格局"]result = predict(texts)print(result)

推理预测的对象是用户,这个流程一般就两个:分词器、模型。分词器会将输入文本处理后输出各种维度的数据,并将其作为模型的输入,最终返回分类的概率分布。

总结

从上面来看,所谓基于预训练大模型来做解决方案,总的来说就是接入大模型,利用大模型已有的泛化能力,训练/微调出一个"领域任务",使其在某个任务上更具有领域性与专业性。

归根来说,其流程就是神经网络模型建模与训练的过程:在基座大模型的基础上,再次训练了一个多分类任务的神经网络模型,以满足特定任务的需要;但模型的基本信息,都是基于基座大模型的。

如果对这里不了解的话,建议可以先去看看神经网络模型的建模与训练推理流程——基于Pytorch建模一个简单的神经网络模型,再做训练与推理;与这里面的流程做一个比对。

如何学习AI大模型?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/9190.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C语言 【大白话讲指针(中)】

在之前的文章中我们已经知道了指针的概念,指针就是一个变量,用来存放地址,地址指向唯一一块内存空间。指针的大小是固定的4/8个字节(32为机器/64位机器)。指针是有类型的,指针的类型决定了指针加减整数的步…

大数据分析在市场营销中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 大数据分析在市场营销中的应用 大数据分析在市场营销中的应用 大数据分析在市场营销中的应用 引言 大数据分析概述 定义与原理 发…

启明云端触觉智能与您相约2024年慕尼黑国际电子元器件博览会,不见不散!

展会信息 展会日期: 2024年11月12-15日 展馆名称: 慕尼黑国际展览中心 MesseMnchen exhibition center 展馆地址: Messegelnde 81829 Mnchen Germany 启明云端&触觉智能展位号:B6-351 诚邀您莅临我司展位,让我们在慕尼黑不见不散! …

OPPO开源Diffusion多语言适配器—— MultilingualSD3-adapter 和 ChineseFLUX.1-adapter

MultilingualSD3-adapter 是为 SD3 量身定制的多语言适配器。 它源自 ECCV 2024 的一篇题为 PEA-Diffusion 的论文。ChineseFLUX.1-adapter是为Flux.1系列机型量身定制的多语言适配器,理论上继承了ByT5,可支持100多种语言,但在中文方面做了额…

【JavaEE初阶】网络原理(4)

欢迎关注个人主页:逸狼 创造不易,可以点点赞吗~ 如有错误,欢迎指出~ 目录 网络层 > IP协议 IP协议报头结构 4位版本 4位首部长度 8位服务类型(TOS) 16位总长度(字节数), 16位标识 3位标志位 13位片偏移 8位生存时间(TTL) 8位协议 16位首部…

树莓派上安装与配置 Nginx Web 服务器教程

在树莓派上配置 Nginx 作为 Web 服务器的步骤如下: 1. 更新树莓派 首先,确保你的树莓派系统是最新的。打开终端并执行以下命令: sudo apt update sudo apt upgrade -y2. 安装 Nginx 在树莓派上安装 Nginx: sudo apt install …

Android Studio 中关于com.github.barteksc:android-pdf-viewer 无法正确加载的问题

Android Studio 的app 模块下,添加依赖: implementation com.github.barteksc:android-pdf-viewer:3.2.0-beta.1 运行程序报错: Caused by: org.gradle.api.internal.artifacts.ivyservice.DefaultLenientConfiguration$ArtifactResolveEx…

[JAVA]Maven项目标准结构介绍

什么是Maven? Maven 是一个强大的项目管理和构建自动化工具,在Java开发中,一个项目通常会依赖许多外部的库,比如开发一个Web应用可能需要依赖Servlet APL,Spring框架等,和需要引入大量的Jar包。往往一个Ja…

Ansys EMC Plus:MHARNESS 串扰演示

Ansys EMC Plus 是一款强大的工具,专门用于分析电磁场及其影响,涵盖电磁兼容性和雷电效应分析等领域。 在本演示中,我们将探讨建立 MHARNESS 仿真的基础知识。这包括构建基本电缆线束、创建 MHARNESS 源和设置 MHARNESS 探针的过程。 概述 …

星环大数据平台--TDH部署

1.1 准备一台虚拟机 正常安装一台新的虚拟机, 内存16G,cpu8核,硬盘50G 1.2 安装前系统配置改动 修改/etc/hosts文件,确保hostname该文件包含节点的hostname和IP地址的映射关系列表。 hostname由数字、小写字母或“-”组成&am…

B+树与聚簇索引以及非聚簇索引的关系

B树、聚簇索引和非聚簇索引是数据库系统中非常重要的概念,它们共同决定了数据的存储和查询效率。本文将详细解释B树的结构,以及聚簇索引和非聚簇索引的区别和联系,使读者能够更好地理解这些概念。 1.B树简介 B树是一种多路平衡树,…

IoTDB 与 HBase 对比详解:架构、功能与性能

五大方向,洞悉 IoTDB 与 HBase 的详尽对比! 在物联网(IoT)领域,数据的采集、存储和分析是确保系统高效运行和决策准确的重要环节。随着物联网设备数量的增加和数据量的爆炸式增长,开发者和决策者们需要选择…

了解RSA和DSA的联系和区别

引言 在信息安全领域,加密算法起着至关重要的作用。RSA(Rivest-Shamir-Adleman)和DSA(Digital Signature Algorithm)是两种常见的公钥加密算法,它们在网络安全领域具有重要的应用价值。本文将对比分析RSA和…

项目管理体系文档,代码评审规范文档,代码审查,代码走查标准化文档(word原件)

1.代码评审(Code Review)简介 1.1Code Review的目的 1.2Code Review的前提 1.3.Code Review需要做什么 1.3.1完整性检查(Completeness) 1.3.2一致性检查(Consistency) 1.3.3正确性检查(Correctness) …

前端算法:树(力扣144、94、145、100、104题)

目录 一、树(Tree) 1.介绍 2.特点 3.基本术语 4.种类 二、树之操作 1.遍历 前序遍历(Pre-order Traversal):访问根节点 -> 遍历左子树 -> 遍历右子树。 中序遍历(In-order Traversal&#xf…

Webserver(5.3)线程池实现

目录 线程池locker.hthreadpool.h 线程池 相比于动态地创建子线程,选择一个已经存在的子线程的代价显然要小得多。至于主线程选择哪个子线程来为新任务服务,有多种方式: 主线程使用某种算法来主动选择子线程。最简单、最常用的算法是随机算…

02_ElementUI

一.前端工程化 1.1 概述 前端工程化是使用软件工程的方法来单独解决前端的开发流程 中模块化、组件化、规范化、自动化的问题,其主要目的为了 提高效率和降低成本。 1.2 NodeJS的安装 Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时环 境,可以使 JavaS…

从无音响Windows 端到 有音响macOS 端实时音频传输播放

以下是从 Windows 端到 macOS 端传输音频的优化方案,基于上述链接中的思路进行调整: Windows 端操作 安装必要软件 安装 Python(确保版本兼容且已正确配置环境变量)。安装 PyAudio 库,可通过 pip install pyaudio 命令…

SpringBoot实现的企业资产管理系统

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

建筑行业智慧知识库的搭建与运用

一、引言 在建筑领域,知识管理是企业持续发展和提升竞争力的关键所在。智慧知识库的构建,不仅能够促进知识的有效传递与共享,还能为项目管理和决策提供有力支持。本文将重点探讨建筑行业智慧知识库构建的价值、实践路径以及需要注意的关键点…