Sentence-BERT实现文本匹配【回归目标函数】

引言

上篇文章我们通过Sentence-Bert提出的分类目标函数来训练句子嵌入模型,本文同样基于Sentence-Bert的架构,但改用回归目标函数。

架构

image-20210923000654664

如上图,计算两个句嵌入 u \pmb u u v \pmb v v​之间的余弦相似度,然后可以使用均方误差(mean-squared-error)作为目标函数。
L = ∣ ∣ y − cosine_sim ( u , v ) ∣ ∣ 2 \mathcal L = ||y - \text{cosine\_sim}(\pmb u,\pmb v)||_2 L=∣∣ycosine_sim(u,v)2
这里的 y y y是真实标签。

回归目标函数的预测不再是整数标签1或0了,而可以为数值。比如对于给定的句子对,可以计算相似度得分。此时推理流程与训练完全相同。

实现

实现采用类似Huggingface的形式,每个文件夹下面有一种模型。分为modelingargumentstrainer等不同的文件。不同的架构放置在不同的文件夹内。

modeling.py:

from dataclasses import dataclassimport torch
from torch import Tensor, nnfrom transformers.file_utils import ModelOutputfrom transformers import (AutoModel,AutoTokenizer,
)import numpy as np
from tqdm.autonotebook import trange
from typing import Optional@dataclass
class BiOutput(ModelOutput):loss: Optional[Tensor] = Nonescores: Optional[Tensor] = Noneclass SentenceBert(nn.Module):def __init__(self,model_name: str,trust_remote_code: bool = True,max_length: int = None,num_classes: int = 2,pooling_mode: str = "mean",normalize_embeddings: bool = False,) -> None:super().__init__()self.model_name = model_nameself.normalize_embeddings = normalize_embeddingsself.device = "cuda" if torch.cuda.is_available() else "cpu"self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)self.max_length = max_lengthself.pooling_mode = pooling_modeself.loss_fct = nn.MSELoss()def sentence_embedding(self, last_hidden_state, attention_mask):if self.pooling_mode == "mean":attention_mask = attention_mask.unsqueeze(-1).float()return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(attention_mask.sum(1), min=1e-9)else:# clsreturn last_hidden_state[:, 0]def encode(self,sentences: str | list[str],batch_size: int = 64,convert_to_tensor: bool = True,show_progress_bar: bool = False,):if isinstance(sentences, str):sentences = [sentences]all_embeddings = []for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):batch = sentences[start_index : start_index + batch_size]features = self.tokenizer(batch,padding=True,truncation=True,return_tensors="pt",return_attention_mask=True,max_length=self.max_length,).to(self.device)out_features = self.model(**features, return_dict=True)embeddings = self.sentence_embedding(out_features.last_hidden_state, features["attention_mask"])if not self.training:embeddings = embeddings.detach()if self.normalize_embeddings:embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)if not convert_to_tensor:embeddings = embeddings.cpu()all_embeddings.extend(embeddings)if convert_to_tensor:all_embeddings = torch.stack(all_embeddings)else:all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])return all_embeddingsdef compute_loss(self, scores, labels):labels = torch.tensor(labels).float().to(self.device)return self.loss_fct(scores, labels.view(-1))def forward(self, source, target, labels) -> BiOutput:"""Args:source :target :"""source_embed = self.encode(source)target_embed = self.encode(target)scores = torch.cosine_similarity(source_embed, target_embed)loss = self.compute_loss(scores, labels)return BiOutput(loss, scores)def save_pretrained(self, output_dir: str):state_dict = self.model.state_dict()state_dict = type(state_dict)({k: v.clone().cpu().contiguous() for k, v in state_dict.items()})self.model.save_pretrained(output_dir, state_dict=state_dict)

整个模型的实现放到modeling.py文件中。

arguments.py:

from dataclasses import dataclass, field
from typing import Optionalimport os@dataclass
class ModelArguments:model_name_or_path: str = field(metadata={"help": "Path to pretrained model"})config_name: Optional[str] = field(default=None,metadata={"help": "Pretrained config name or path if not the same as model_name"},)tokenizer_name: Optional[str] = field(default=None,metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},)@dataclass
class DataArguments:train_data_path: str = field(default=None, metadata={"help": "Path to train corpus"})eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})max_length: int = field(default=512,metadata={"help": "The maximum total input sequence length after tokenization for input text."},)def __post_init__(self):if not os.path.exists(self.train_data_path):raise FileNotFoundError(f"cannot find file: {self.train_data_path}, please set a true path")if not os.path.exists(self.eval_data_path):raise FileNotFoundError(f"cannot find file: {self.eval_data_path}, please set a true path")

定义了模型和数据相关参数。

dataset.py:

from torch.utils.data import Dataset
from datasets import Dataset as dt
import pandas as pdfrom utils import build_dataframe_from_csvclass PairDataset(Dataset):def __init__(self, data_path: str) -> None:df = build_dataframe_from_csv(data_path)self.dataset = dt.from_pandas(df, split="train")self.total_len = len(self.dataset)def __len__(self):return self.total_lendef __getitem__(self, index) -> dict[str, str]:query1 = self.dataset[index]["query1"]query2 = self.dataset[index]["query2"]label = self.dataset[index]["label"]return {"query1": query1, "query2": query2, "label": label}class PairCollator:def __call__(self, features) -> dict[str, list[str]]:queries1 = []queries2 = []labels = []for feature in features:queries1.append(feature["query1"])queries2.append(feature["query2"])labels.append(feature["label"])return {"source": queries1, "target": queries2, "labels": labels}

数据集类考虑了LCQMC数据集的格式,即成对的语句和一个数值标签。类似:

Hello.	Hi.	1
Nice to see you.	Nice	0

trainer.py:

import torch
from transformers.trainer import Trainerfrom typing import Optional
import os
import loggingfrom modeling import SentenceBertTRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)class BiTrainer(Trainer):def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):outputs = model(**inputs)loss = outputs.lossreturn (loss, outputs) if return_outputs else lossdef _save(self, output_dir: Optional[str] = None, state_dict=None):# If we are executing this function, we are the process zero, so we don't check for that.output_dir = output_dir if output_dir is not None else self.args.output_diros.makedirs(output_dir, exist_ok=True)logger.info(f"Saving model checkpoint to {output_dir}")self.model.save_pretrained(output_dir)if self.tokenizer is not None:self.tokenizer.save_pretrained(output_dir)# Good practice: save your training arguments together with the trained modeltorch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

继承🤗 Transformers的Trainer类,重写了compute_loss_save方法。

这样我们就可以利用🤗 Transformers来训练我们的模型了。

utils.py:

import torch
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from typing import Tupledef build_dataframe_from_csv(dataset_csv: str) -> pd.DataFrame:df = pd.read_csv(dataset_csv,sep="\t",header=None,names=["query1", "query2", "label"],)return dfdef compute_spearmanr(x, y):return spearmanr(x, y).correlationdef compute_pearsonr(x, y):return pearsonr(x, y)[0]def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):"""Copied from https://github.com/UKPLab/sentence-transformers/tree/master"""assert len(scores) == len(labels)rows = list(zip(scores, labels))rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)print(rows)max_acc = 0best_threshold = -1# positive examples number so farpositive_so_far = 0# remain negative examplesremaining_negatives = sum(labels == 0)for i in range(len(rows) - 1):score, label = rows[i]if label == 1:positive_so_far += 1else:remaining_negatives -= 1acc = (positive_so_far + remaining_negatives) / len(labels)if acc > max_acc:max_acc = accbest_threshold = (rows[i][0] + rows[i + 1][0]) / 2return max_acc, best_thresholddef metrics(y: torch.Tensor, y_pred: torch.Tensor) -> Tuple[float, float, float, float]:TP = ((y_pred == 1) & (y == 1)).sum().float()  # True PositiveTN = ((y_pred == 0) & (y == 0)).sum().float()  # True NegativeFN = ((y_pred == 0) & (y == 1)).sum().float()  # False NegatvieFP = ((y_pred == 1) & (y == 0)).sum().float()  # False Positivep = TP / (TP + FP).clamp(min=1e-8)  # Precisionr = TP / (TP + FN).clamp(min=1e-8)  # RecallF1 = 2 * r * p / (r + p).clamp(min=1e-8)  # F1 scoreacc = (TP + TN) / (TP + TN + FP + FN).clamp(min=1e-8)  # Accuraryreturn acc, p, r, F1def compute_metrics(predicts, labels):return metrics(labels, predicts)

定义了一些帮助函数,从sentence-transformers库中拷贝了寻找最佳准确率阈值的实现find_best_acc_and_threshold

除了准确率,还计算了句嵌入的余弦相似度与真实标签之间的斯皮尔曼等级相关系数指标。

最后定义训练和测试脚本。

train.py:

from transformers import set_seed, HfArgumentParser, TrainingArgumentsimport logging
from pathlib import Pathfrom datetime import datetimefrom modeling import SentenceBert
from trainer import BiTrainer
from arguments import DataArguments, ModelArguments
from dataset import PairCollator, PairDatasetlogger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",level=logging.INFO,
)def main():parser = HfArgumentParser((TrainingArguments, DataArguments, ModelArguments))training_args, data_args, model_args = parser.parse_args_into_dataclasses()# 根据当前时间生成输出目录output_dir = f"{training_args.output_dir}/{model_args.model_name_or_path.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"training_args.output_dir = output_dirlogger.info(f"Training parameters {training_args}")logger.info(f"Data parameters {data_args}")logger.info(f"Model parameters {model_args}")# 设置随机种子set_seed(training_args.seed)# 加载预训练模型model = SentenceBert(model_args.model_name_or_path,trust_remote_code=True,max_length=data_args.max_length,)tokenizer = model.tokenizer# 构建训练和测试集train_dataset = PairDataset(data_args.train_data_path)eval_dataset = PairDataset(data_args.eval_data_path)# 传入参数trainer = BiTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=PairCollator(),tokenizer=tokenizer,)Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)# 开始训练trainer.train()trainer.save_model()if __name__ == "__main__":main()

训练

基于train.py定义了train.sh传入相关参数:

timestamp=$(date +%Y%m%d%H%M)
logfile="train_${timestamp}.log"# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=3 nohup python train.py \--model_name_or_path=hfl/chinese-macbert-large \--output_dir=output \--train_data_path=data/train.txt \--eval_data_path=data/dev.txt \--num_train_epochs=3 \--save_total_limit=5 \--learning_rate=2e-5 \--weight_decay=0.01 \--warmup_ratio=0.01 \--bf16=True \--eval_strategy=epoch \--save_strategy=epoch \--per_device_train_batch_size=64 \--report_to="none" \--remove_unused_columns=False \--max_length=128 \> "$logfile" 2>&1 &

以上参数根据个人环境修改,这里使用的是哈工大的chinese-macbert-large预训练模型。

注意:

  • --remove_unused_columns是必须的。
  • 通过bf16=True可以加速训练同时不影响效果。
  • 其他参数可以自己调整。
100%|██████████| 18655/18655 [1:17:23<00:00,  4.44it/s]
100%|██████████| 18655/18655 [1:17:23<00:00,  4.02it/s]
09/02/2024 21:02:41 - INFO - trainer - Saving model checkpoint to output/hfl-chinese-macbert-large-2024-09-02_19-45-12
{'eval_loss': 0.09294428676366806, 'eval_runtime': 56.1261, 'eval_samples_per_second': 156.825, 'eval_steps_per_second': 19.617, 'epoch': 5.0}
{'train_runtime': 4643.261, 'train_samples_per_second': 257.11, 'train_steps_per_second': 4.018, 'train_loss': 0.049199433276877584, 'epoch': 5.0}

这里训练了5轮,我们拿最后保存的模型output/hfl-chinese-macbert-large-2024-09-02_19-45-12进行测试。

参数忘改了,为了便于比较,实际上下面的结果是以3轮的训练结果验证的。

测试

test.py: 测试脚本见后文的完整代码。

test.sh:

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=0 python test.py \--model_name_or_path=output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193 \--test_data_path=data/test.txt

输出:

TestArguments(model_name_or_path='output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.77it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.89it/s]
max_acc: 0.8832, best_threshold: 0.794167
spearman corr: 0.7795 |  pearson_corr corr: 0.7668 | compute time: 22.25s
accuracy=0.883 precision=0.876 recal=0.893 f1 score=0.8843

测试集上的准确率达到88.3%,这种以回归目标函数进行训练的效果没有分类的好。

完整代码

完整代码: →点此←

本文代码是和某次commit相关的,Master分支上的代码随时可能会被优化。

参考

  1. [论文笔记]Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1522999.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Webpack详解与配置环境

webpack&#xff1a;webpack网址 1、工作原理&#xff1a; Webpack是一个非常强大的静态模块的打包工具。从文件入口开始&#xff0c;递归解析以来关系&#xff0c;然后将所有模块打包成一个或多个budle文件。 2、webpack核心概念&#xff1a; Entry&#xff1a;入口起点(en…

【STM32+HAL库】---- 通用定时器实现外部脉冲计数

硬件开发板&#xff1a;STM32G0B1RET6 软件平台&#xff1a;cubemaxkeilVScode1 新建cubemax工程 1.1 配置系统时钟RCC 1.2 配置定时器 选择通用定时器TIM2&#xff0c;时钟源选择ETR2&#xff0c;对应的输入端口为PA0引脚&#xff0c;预分频系数为0&#xff0c;重装载值选择…

Python简易IDE工作界面制作

、 休闲一下&#xff0c;学习编程还是要学习一些界面编程&#xff0c;能够根据需要制作图形操作界面&#xff0c;这样我们开发的程序才能方便操作和使用&#xff0c;同时获得更友好的人机交互体验。下面是一个用PyQt5制作的简易界面&#xff0c;供大学参考。如下图所示&a…

在SpringMVC中用fmt标签实现国际化/多语言

SpringMVC中用fmt标签实现国际化主要解决界面的多语言化&#xff0c;ftm标签会根据浏览器的语言值来先择对应的文件配置&#xff0c;如中文简体的浏览器值是zh_CN,那么ftm标签就会用以zh_CN.properties结尾的配置文件中的key来取值&#xff0c;从而实现自多语言的自动切换&…

自动化运维之SaltStack 部署应用

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

【Steam游戏星露谷物语添加Mod步骤】

Steam游戏星露谷物语添加Mod步骤 星露谷物语添加拖拉机模组一、安装SMAPI二、正式开始添加MOD 星露谷物语添加拖拉机模组 一、安装SMAPI 星露谷物语添加拖拉机mod为例&#xff0c;添加其它mod一样的步骤。 首先&#xff0c;打开Steam&#xff0c;打开一次星露谷物语这款游戏&…

论文速读纪录 - 202408

特别鸣谢kimi&#xff0c;以下论文均在kimi辅助下阅读。 目录 RMIB: Representation Matching Information Bottleneck for Matching Text RepresentationsAttentionRank: Unsupervised keyphrase Extraction using Self and Cross AttentionsANSWERING COMPLEX OPEN-DOMAIN …

标准库标头 <optional> (C++17)学习之optional

类模板 std::optional 管理一个可选 &#xfeff;的所含值&#xff0c;即既可以存在也可以不存在的值。 一种常见的 optional 使用情况是作为可能失败的函数的返回值。与如 std::pair<T, bool> 等其他手段相比&#xff0c;optional 可以很好地处理构造开销高昂的对象&a…

DataGridView用法合集【精品】

1.当前的单元格属性取得、变更 [VB.NET] Console.WriteLine(DataGridView1.CurrentCell.Value) Console.WriteLine(DataGridView1.CurrentCell.ColumnIndex) Console.WriteLine(DataGridView1.CurrentCell.RowIndex) DataGridView1.CurrentCell DataGridView1(0, 0) [C#] Con…

虚幻引擎VR游戏开发02 | 性能优化设置

常识&#xff1a;VR需要保持至少90 FPS的刷新率&#xff0c;以避免用户体验到延迟或晕眩感。以下是优化性能的一系列设置&#xff08;make sure the frame rate does not drop below a certain threshold&#xff09; In project setting-> &#xff08;以下十个设置都在pr…

基于php+vue+uniapp的医院预约挂号系统小程序

开发语言&#xff1a;PHP框架&#xff1a;phpuniapp数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;PhpStorm 系统展示 后台登录界面 管理员功能界面 用户管理 医生管理 科室分类管理 医生信息管理 预…

机器人外呼有哪些优势?

机器人外呼&#xff0c;作为一种结合了计算机技术和人工智能技术的自动化工具&#xff0c;具有多重显著优势。以下是其主要优势的详细阐述&#xff1a; ### 1. 高效性 * **大幅提升工作效率**&#xff1a;机器人外呼可以全天候、不间断地进行工作&#xff0c;不受时间、地点和…

【优质源码】3D多人在线游戏,前端ThreeJS,后端NodeJS

3D多人在线游戏 【源码】3D多人在线游戏源码&#xff0c;前端ThreeJS&#xff0c;后端NodeJS&#xff0c;完整源码。 游戏画面 启动方法 先启动服务器端。 在目录&#xff0c;3D-multi-player-main\3D-multi-player-main\nodeapps\blockland 中&#xff0c;运行&#xff1a…

Elasticsearch之原理详解

简介 ES是使用 Java 编写的一种开源搜索引擎&#xff0c;它在内部使用 Lucene 做索引与搜索&#xff0c;通过对 Lucene 的封装&#xff0c;隐藏了 Lucene 的复杂性&#xff0c;取而代之的提供一套简单一致的 RESTful API 然而&#xff0c;Elasticsearch 不仅仅是 Lucene&#…

设计模式及创建型模式-python版

1 架构模式与设计模式 架构模式搞层次的设计模式&#xff0c; 描述系统整体结构和组织方式&#xff0c;设计模式是针对某个问题的解决方案&#xff0c;是一种解决问题的思路。 2 设计模式的分类 2.1 创建型模式 单例模式&#xff0c;工厂方法模式&#xff0c;抽象工厂模式&…

JVM2-JVM组成、字节码文件、类的生命周期、类加载器

目录 Java虚拟机的组成 字节码文件 字节码文件打开方式 字节码文件的组成 基本信息 Magic魔数 主副版本号 常量池 字段 方法 属性 字节码常用工具 javap jclasslib插件 Arthas 类的生命周期 概述 加载阶段 连接阶段 验证 准备 解析 初始化阶段 类加载器…

linux 下一跳缓存,early demux(‌早期解复用)‌介绍

3.6版本以后的下一跳缓存 3.6版本移除了FIB查找前的路由缓存。这意味着每一个接收发送的skb现在都必须要进行FIB查找了。这样的好处是现在查找路由的代价变得稳定(consistent)了。3.6版本实际上是将FIB查找缓存到了下一跳(fib_nh)结构上&#xff0c;也就是下一跳缓存下一跳缓存…

ESP32无线WiFi芯片模组,设备物联网连接通信,产品智能化交互升级

在数字化浪潮的推动下&#xff0c;我们正步入一个万物互联的新时代。物联网&#xff08;IoT&#xff09;技术&#xff0c;作为连接物理世界与数字世界的桥梁&#xff0c;正逐渐渗透到我们生活的每一个角落。 乐鑫正通过其创新的无线WiFi芯片模组&#xff0c;为这些领域的发展提…

界面控件DevExpress中文教程:如何使用AI扩展Excel计算?

DevExpress WinForms拥有180组件和UI库&#xff0c;能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForms能完美构建流畅、美观且易于使用的应用程序&#xff0c;无论是Office风格的界面&#xff0c;还是分析处理大批量的业务数据&#xff0c;它都能轻松胜…

Elasticsearch的Restful风格API

前言&#xff1a;本博客仅作记录学习使用&#xff0c;部分图片出自网络&#xff0c;如有侵犯您的权益&#xff0c;请联系删除 1、Restful及JSON格式 RESTFUL是一种网络应用程序的设计风格和开发方式&#xff0c;基于HTTP&#xff0c;可以使用 XML 格式定义或 JSON 格式定义。R…