Trainer API训练属于自己行业的本地大语言模型 医疗本地问答大模型示例

Trainer API 是 Hugging Face transformers 库中强大而灵活的工具,简化了深度学习模型的训练和评估过程。通过提供高层次的接口和多种功能,Trainer API 使研究人员和开发者能够更快地构建和优化自然语言处理模型

文章目录

  • 前言
  • 一、Trainer API
    • 它能做什么?
    • 基本步骤
    • 简单示例
  • 二、详细步骤
    • 安装依赖
    • 准备数据集
    • 加载本地模型和分词器
    • 数据预处理
    • 设置训练参数
    • 创建 Trainer 实例
    • 开始训练
    • 评估模型
    • 保存模型
  • 三、 Trainer API 对大模型的限制
  • 四、用医疗数据微调 ChatGPT-3.5,让它成为智能的医疗问答助手
    • 使用 Ollama 安装大语言模型
    • 准备高质量的医疗数据
    • 数据准备与清洗
    • 数据分割
    • 创建训练数据格式
    • 使用 Trainer API 进行微调
    • 数据编码
    • 设置训练参数
    • 创建 Trainer 实例
    • 开始微调
    • 评估模型
    • 保存微调后的模型
    • 测试模型
    • 持续优化
  • 总结


前言

在如今这个数字化飞速发展的时代,越来越多的企业和个人意识到,通用的AI模型可能无法满足特定行业或业务的独特需求。

企业在日常运营中面对的语言、文化和行业术语,都是千差万别的。而本地化的语言模型,不仅可以更精准地理解这些内容,还能提供更贴近用户需求的服务。比如,医疗行业的AI助手需要懂得专业术语,而电商平台的客服AI则得知道怎样与顾客更好地互动。

更重要的是,随着数据隐私和安全问题日益突显,很多企业开始倾向于在本地进行模型训练,这样不仅可以保护敏感数据,还能在数据处理上保持灵活性。
总之,定制化训练本地大语言模型的需求,将会是未来的趋势,帮助我们更好地应对各行各业的挑战!


一、Trainer API

Trainer API 是 Hugging Face 的 transformers 库中一个帮助你训练和评估机器学习模型的工具。它提供了一个简单的接口,让你无需深入了解复杂的训练细节,就可以专注于模型的开发。

它能做什么?

  1. 训练模型:自动处理训练过程,只需要提供数据和模型。
  2. 评估性能:在每个训练周期结束时,自动评估模型的性能。
  3. 记录日志:可以记录训练进度,比如损失和准确率,便于分析。
  4. 保存模型:训练完毕后,方便地保存模型以供后续使用。
  5. 扩展性:允许你添加自定义逻辑,比如特殊的评估指标。

基本步骤

  1. 准备数据:将你的数据集准备好,并分成训练集和验证集。
  2. 选择模型:选择一个预训练模型,比如 BERT 或 GPT。
  3. 设置训练参数:定义训练的超参数,比如学习率、批次大小和训练轮数。
  4. 创建 Trainer 实例:将模型、参数和数据传给 Trainer。
  5. 开始训练:调用 train() 方法开始训练。
  6. 评估模型:使用 evaluate() 方法来评估模型性能。
  7. 保存模型:使用save_model() 方法保存训练好的模型。

简单示例

下面是一个非常简化的代码示例,展示如何使用 Trainer API:

from transformers import Trainer, TrainingArguments, BertForSequenceClassification# 1. 加载模型
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")# 2. 设置训练参数
training_args = TrainingArguments(output_dir='./results',   # 输出结果存放位置num_train_epochs=3,       # 训练3个周期per_device_train_batch_size=8,  # 每个设备的批次大小
)# 3. 创建 Trainer 实例
trainer = Trainer(model=model,              # 要训练的模型args=training_args,       # 训练参数train_dataset=train_dataset,  # 训练数据集eval_dataset=valid_dataset   # 验证数据集
)# 4. 开始训练
trainer.train()# 5. 评估模型
results = trainer.evaluate()
print(results)# 6. 保存模型
trainer.save_model('./final_model')

二、详细步骤

安装依赖

确保你已经安装了 transformers 和 datasets 库。如果尚未安装,可以使用以下命令:

pip install transformers datasets

准备数据集

首先,你需要准备一个数据集。假设你有一个 CSV 文件,包含问答对(问题和答案)。你需要将数据加载为适合 Trainer 的格式。

import pandas as pd
from datasets import Dataset# 加载数据
data = pd.read_csv('your_dataset.csv')  # 替换为你的数据文件
dataset = Dataset.from_pandas(data)# 分割数据集
train_test_split = dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
valid_dataset = train_test_split['test']

加载本地模型和分词器

加载你本地的大模型和相应的分词器。假设你的模型存放在 path/to/your/model 目录下。

from transformers import AutoTokenizer, AutoModelForSequenceClassificationmodel_path = "path/to/your/model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

数据预处理

使用分词器对数据进行编码,将文本转为模型可以接受的格式。

def preprocess_function(examples):return tokenizer(examples['question'], padding="max_length", truncation=True)train_dataset = train_dataset.map(preprocess_function, batched=True)
valid_dataset = valid_dataset.map(preprocess_function, batched=True)

设置训练参数

定义训练超参数。这些参数会影响训练过程和模型性能。

from transformers import TrainingArgumentstraining_args = TrainingArguments(output_dir='./results',              # 输出目录evaluation_strategy="epoch",          # 每个周期评估一次learning_rate=2e-5,                   # 学习率per_device_train_batch_size=8,        # 训练批次大小per_device_eval_batch_size=8,         # 评估批次大小num_train_epochs=3,                   # 训练轮数weight_decay=0.01                     # 权重衰减
)

创建 Trainer 实例

将模型、训练参数和数据集传给 Trainer 实例。

from transformers import Trainertrainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=valid_dataset,
)

开始训练

调用 train() 方法开始训练模型。

trainer.train()

评估模型

训练完成后,可以使用验证集对模型进行评估。

results = trainer.evaluate()
print(results)

保存模型

最后,将训练好的模型和分词器保存到指定目录。

trainer.save_model('./final_model')  # 保存模型
tokenizer.save_pretrained('./final_model')  # 保存分词器

三、 Trainer API 对大模型的限制

  1. 内存限制
    GPU 内存:大模型可能会超出可用 GPU 内存,导致训练失败。需要确保你的硬件能够支持模型的大小,或者使用梯度累积等方法减小内存压力。
    CPU 内存:如果使用 CPU 训练,模型和数据也可能占用大量内存。
  2. 训练时间
    大模型通常需要更长的训练时间,尤其是在较大的数据集上。这可能导致实验周期变长。
  3. 超参数调整
    对于大模型,超参数的选择(如学习率、批次大小等)可能会影响训练的稳定性和收敛速度。调试过程可能更加复杂。
  4. 分布式训练
    如果需要在多个 GPU 上进行分布式训练,配置和管理会更加复杂,涉及到同步和通信问题。
  5. 推理速度
    大模型的推理速度相对较慢,这在实时应用中可能成为瓶颈。
  6. 兼容性
    某些大模型可能不支持 Trainer API 的所有功能,例如特定的损失函数或自定义回调。

应对策略

  • 使用混合精度训练:可以减少内存使用并加速训练。
  • 模型压缩:尝试量化或剪枝等技术,以减少模型的大小和计算需求。
  • 分布式训练:使用accelerate 库或其他框架进行分布式训练。
  • 动态批次大小:根据可用资源动态调整批次大小。

四、用医疗数据微调 ChatGPT-3.5,让它成为智能的医疗问答助手

使用 Ollama 安装大语言模型

首先,我们需要在本地安装 Ollama。这是一个非常便捷的工具,能够让你轻松下载和运行大型语言模型。打开终端,输入以下命令:

curl -sSfL https://ollama.com/download.sh | sh

安装完成后,使用以下命令下载 GPT-3.5 模型:

ollama pull gpt-3.5

接下来,我们需要启动 Ollama 服务器,以便通过 API 调用模型:

ollama serve gpt-3.5

这会在 http://localhost:11434 启动一个服务器,你可以通过这个地址访问模型。

准备高质量的医疗数据

用高质量的医疗数据来微调 ChatGPT-3.5,让它成为一个更智能的医疗问答助手,首先,我们需要收集一些高质量的医疗问答数据。这里有几个建议:

  • 数据来源:可以从权威的医疗网站、学术论文以及知名的医疗问答平台上收集数据。确保这些信息是最新且可靠的。
  • 数据格式:你的数据最好是以问答对的形式保存,比如 CSV 文件,这样方便后续处理。

例如,你的数据文件可能长这样:

Question,Answer
"什么是糖尿病?","糖尿病是一种慢性疾病,血糖水平过高。"
"流感的症状有哪些?","流感的常见症状包括发热、咳嗽和全身疼痛。"

确保问题和答案都是准确且具有代表性的。

数据准备与清洗

使用 Pandas 读取和清洗数据,确保数据质量,避免空值和重复项:

import pandas as pd# 读取数据
data = pd.read_csv('medical_qa_chinese.csv')# 数据清洗
data.dropna(inplace=True)  # 删除空值
data = data[data['Question'].str.strip() != '']  # 删除空问题
data = data.drop_duplicates(subset='Question')  # 删除重复问题

清洗后的数据可以更好地支持模型训练。

数据分割

将数据集分为训练集和验证集,以便于后续评估模型:

from sklearn.model_selection import train_test_splittrain_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)

创建训练数据格式

将问题和答案整理为模型可接受的 JSON 格式,方便后续处理:

train_pairs = [{"prompt": q, "completion": a} for q, a in zip(train_data['Question'], train_data['Answer'])]
valid_pairs = [{"prompt": q, "completion": a} for q, a in zip(valid_data['Question'], valid_data['Answer'])]import json# 保存训练数据
with open('train_data.json', 'w', encoding='utf-8') as f:json.dump(train_pairs, f, ensure_ascii=False)# 保存验证数据
with open('valid_data.json', 'w', encoding='utf-8') as f:json.dump(valid_pairs, f, ensure_ascii=False)

使用 Trainer API 进行微调

确保你安装了 Hugging Face 的 Transformers 和 Torch 库:

pip install transformers torch

接下来,加载 GPT-3.5 模型并进行微调:

from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer# 加载 Ollama 模型
model_name = "ollama/gpt-3.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

数据编码

将训练和验证数据编码为模型所需的格式:

train_encodings = tokenizer(train_pairs, truncation=True, padding=True, return_tensors='pt')
valid_encodings = tokenizer(valid_pairs, truncation=True, padding=True, return_tensors='pt')import torchclass MedicalDataset(torch.utils.data.Dataset):def __init__(self, encodings):self.encodings = encodingsdef __getitem__(self, idx):return {key: val[idx] for key, val in self.encodings.items()}def __len__(self):return len(self.encodings['input_ids'])train_dataset = MedicalDataset(train_encodings)
valid_dataset = MedicalDataset(valid_encodings)

设置训练参数

配置训练参数,以确保模型在训练过程中达到最佳效果。以下是一些关键参数:

  • 学习率:控制模型学习的速度,太高可能导致训练不稳定,太低则训练过慢。
  • 训练轮数:通常可以从 3 到 5 轮开始,具体根据模型的学习情况进行调整。
  • 批量大小:根据你的硬件配置选择合适的批量大小,通常 4-16 之间。
training_args = TrainingArguments(output_dir='./results',              # 输出目录evaluation_strategy="epoch",          # 每个训练周期后进行评估learning_rate=5e-5,                   # 学习率per_device_train_batch_size=4,        # 每个设备的训练批次大小num_train_epochs=5,                   # 增加训练轮数以提高效果weight_decay=0.01                     # 权重衰减
)

创建 Trainer 实例

创建 Trainer 实例并传入模型、训练参数和数据集:

trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=valid_dataset,
)

开始微调

调用 train() 方法开始微调过程:

trainer.train()

微调过程:

  • 在训练期间,模型会逐步学习如何从输入问题中生成合适的答案。
  • 每个 epoch 后,模型会在验证集上评估性能,输出损失和准确率等指标。
  • 如果发现模型在某个 epoch 的性能停滞,可以考虑调整学习率或增加训练轮数。

评估模型

训练完成后,使用验证集评估模型的性能:

results = trainer.evaluate()
print(results)

评估结果将提供模型在验证集上的损失、准确率等信息。这些指标可以帮助你判断模型是否过拟合或欠拟合。

保存微调后的模型

将微调后的模型和分词器保存,以便后续使用:

trainer.save_model('./fine_tuned_model')
tokenizer.save_pretrained('./fine_tuned_model')

测试模型

使用新的问题来测试微调后的模型,确保其能够正确回答医疗相关问题:

test_question = "糖尿病的饮食注意事项有哪些?"
input_ids = tokenizer.encode(test_question, return_tensors='pt')
outputs = model.generate(input_ids)
predicted_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"预测答案: {predicted_answer}")

持续优化

如果模型的表现不尽如人意,可以通过以下方式进一步优化:

  • 数据扩充:增加更多高质量的数据,尤其是针对模型表现不好的问题。
  • 超参数调整:根据验证集的评估结果调整学习率、训练轮数等超参数。
  • 增量训练:在新的数据上进行增量训练,而不是完全从头开始训练。

简易版的结构说明
简易版的结构说明


总结

调优大语言模型的过程是一个系统化的流程,从数据准备到模型训练、评估和保存,都需要精心设计和执行。确保数据的质量和模型的适用性是关键。在整个过程中,监控模型的性能和调整超参数也是非常重要的。如果你有具体的问题或想了解某个步骤的更多细节,请随时询问!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1540929.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

RNN的反向传播

目录 1.RNN网络:通过时间反向传播(through time back propagate TTBP) 2.RNN梯度分析 2.1隐藏状态和输出 2.2正向传播: 2.3反向传播: 2.4问题瓶颈: 3.截断时间步分类: 4.截断策略比较 5.反向传播的细节 ​编辑…

达梦数据库踩坑

提示:第一次接触达梦,是真的不好用,各种报错不提示详细信息,吐槽归吐槽,还是需要学习使用的。 前言 题主刚接触达梦数据库时,本来是想下载官网的连接工具进行数据库连接的,但是谁曾想&#xff…

监控易监测对象及指标之:全面监控GBase数据库

在数字化时代,数据库作为企业核心数据资产的管理中心,其稳定性和性能直接关系到业务的连续性和企业的运营效率。GBase数据库作为高性能的分布式数据库系统,广泛应用于各类业务场景。为了确保GBase数据库的稳定运行和高效性能,对其…

git安装包夸克网盘下载

git安装包夸克网盘下载 git夸克网盘 git网站上的安装包下载速度有点慢,因此为了方便以后下载就将文件保存到夸克网盘上,链接:我用夸克网盘分享了「git」,点击链接即可保存。 链接:https://pan.quark.cn/s/07c73c4a30…

C++速通LeetCode中等第12题-矩阵置零(空间O(1)含注释)

class Solution { public:void setZeroes(vector<vector<int>>& matrix) {int m matrix.size();int n matrix[0].size();int flag_col0 false, flag_row0 false;//先记录第一行和第一列是否有零for (int i 0; i < m; i) {if (!matrix[i][0]) {flag_col…

基于单片机的智能健康水杯设计

摘要&#xff1a;随着时代的发展&#xff0c;单片机领域不断扩张。人工智能产品的出现改变了人们的生活方式。智能产品不仅加快了人们的生活节奏&#xff0c;还为人们的安全提供了保障。在快节奏生活的同时&#xff0c;人们开始越来越关注自己的身体健康&#xff0c;基于 52 单…

高级java每日一道面试题-2024年9月20日-分布式篇-什么是CAP理论?

如果有遗漏,评论区告诉我进行补充 面试官: 什么是CAP理论&#xff1f; 我回答: 在Java高级面试中&#xff0c;CAP理论是一个经常被提及的重要概念&#xff0c;它对于理解分布式系统的设计和优化至关重要。CAP理论是分布式系统理论中的一个重要概念&#xff0c;它描述了一个分…

c++11右值引用和移动语义

一.左值引用和右值引用 什么是左值引用&#xff0c;什么是右值引用 左值是一个表示数据的表达式&#xff08;变量名解引用的指针&#xff09;&#xff0c;我们可以获取到它的地址&#xff0c;可以对它赋值&#xff0c;左值可以出现在符号的左边。使用const修饰后&#xff0c;…

通威股份半年报业绩巨降:销售费用大增,近一年股价跌四成

《港湾商业观察》施子夫 王璐 光伏领域龙头企业通威股份&#xff08;600438.SH&#xff09;交出的半年报延续了2023年营收和净利润双下滑趋势&#xff0c;幅度显得更大。 即便受行业波动影响&#xff0c;但如何重整及提升盈利能力&#xff0c;通威股份还需要给出解决方案。​…

vue项目关闭浏览器中的全屏错误提示

vue.config.js module.exports {devServer: {client: {overlay: false }} }

c++优先级队列自定义排序实现方式

1、使用常规方法实现 使用结构体实现自定义排序函数 2、使用lambda表达式实现 使用lambda表达式实现自定义排序函数 3、具体实现如下&#xff1a; #include <iostream> #include <queue> #include <vector>using namespace std; using Pair pair<in…

什么是大模型的泛化能力?

大模型的泛化能力指的是模型在未见过的数据上表现的能力&#xff0c;即模型不仅能在训练数据上表现良好&#xff0c;也能在新的、未知的数据集上保持良好的性能。这种能力是衡量机器学习模型优劣的重要指标之一。 泛化能力的好处包括但不限于&#xff1a; 提高模型的适应性&a…

Qt 构建目录

Qt Creator新建项目时&#xff0c;选择构建套件是必要的一环&#xff1a; 构建目录的默认设置 在Qt Creator中&#xff0c;项目的构建目录通常是默认设置的&#xff0c;位于项目文件夹内的一个子文件夹中&#xff0c;如&#xff1a;build-项目名-Desktop_Qt_版本号_编译器类型_…

电子烟智能化创新体验:WTK6900P语音交互芯片方案,融合频谱计算、精准语音识别与流畅音频播报

一&#xff1a;开发背景 在这个科技日新月异的时代&#xff0c;每一个细节的创新都是对传统的一次超越。今天&#xff0c;我们自豪地宣布一项革命性的融合——将先进的语音识别技术与电子烟相结合&#xff0c;通过WTK6900P芯片的卓越性能&#xff0c;为您开启前所未有的个性化…

稀疏向量 milvus存储检索RAG使用案例

参考&#xff1a; https://milvus.io/docs/hybrid_search_with_milvus.md milvus使用不方便&#xff1a; 1&#xff09;离线计算向量很慢BGEM3EmbeddingFunction 2&#xff09;milvus安装环境支持很多问题&#xff0c;不支持windows、centos等 在线demo&#xff1a; https://co…

Hadoop 常用生态组件

Hadoop核心组件 安装 Hadoop 时&#xff0c;通常会自动包含以下几个关键核心组件&#xff0c;特别是如果使用了完整的 Hadoop 发行版&#xff08;如 Apache Hadoop、Cloudera 或 Hortonworks 等&#xff09;。这些组件构成了 Hadoop 的核心&#xff1a; 1. HDFS&#xff08;H…

-bash: apt-get: command not found -bash: yum: command not found

1. 现象&#xff1a; 1.1. 容器内使用apt-get, yum 提示命令未找到 1.2. dockerfile制作镜像时候&#xff0c;使用apt-get, yum同样报此错误。 2.原因&#xff1a; 2.1. linux 分为&#xff1a; 1. RedHat系列&#xff1a; Redhat、Centos、Fedora等 2. Debian系列&#xff1a…

ABAP-Swagger 一种公开 ABAP REST 服务的方法

ABAP-Swagger An approach to expose ABAP REST services 一种公开 ABAP REST 服务的方法 Usage 1: develop a class in ABAP with public methods 2: implement interface ZIF_SWAG_HANDLER, and register the public methods(example method zif_swag_handler~meta) 3: …

初体验《SpringCloud 核心组件Eureka》

文章目录 1.案例准备1.1 案例说明1.2 案例数据库准备1.3 环境搭建1.3.1. 创建一个空的项目1.3.2. 创建Maven工程1.3.3. 配置父工程依赖&#xff0c;SpringCloud版本以及对应的SpringBoot版本1.3.4. 创建公共模块1.3.5. 创建用户模块工程1.3.5.1 引入依赖以及配置文件1.3.5.2 在…

C/C++通过CLion2024进行Linux远程开发保姆级教学

目前来说&#xff0c;对Linux远程开发支持相对比较好的也就是Clion和VSCode了&#xff0c;这两个其实对于C和C语言开发都很友好&#xff0c;大可不必过于纠结使用那个&#xff0c;至于VS和QtCreator&#xff0c;前者太过重量级了&#xff0c;后者更是不用说&#xff0c;主要用于…