11. DPO 微调示例:根据人类偏好优化LLM大语言模型

在部署大模型之后,我们必然要和微调打交道。现在大模型的微调有非常多的方法,过去的文章中提到的微调方法通常依赖于问题和答案对,标注成本较高。

2023 年所提出的 Direct Preference Optimization(DPO)为我们提供了一种无需标准标注答案的高效微调方法。DPO 依赖于人类对文本的偏好对(preference pairs),也就是说,数据集中只包含人类对两段文本中哪段更好的判断,而不是具体的正确答案。

在本文中,我们将利用 DPO 来微调一个模型让其按照偏好进行输出。这篇文章也为生成式人工智能导论课程中 HW6: LLM Values Alignment 提供中文引导。

代码文件下载 | 作业PDF

安装和导入一些必要的库

pip install bitsandbytes==0.43.1 datasets==2.19.0 peft==0.10.0 trl==0.8.6 accelerate==0.29.3
import os
import re
import jsonimport torch
import pandas as pd
from tqdm.auto import tqdmfrom datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, GenerationConfig
from trl import DPOTrainer

可能的问题:Keras 3 与 Transformers 不兼容

在导入时,你可能会看到以下报错:

RuntimeError: Failed to import trl.trainer.dpo_trainer because of the following error (look up to see its traceback):
Failed to import transformers.trainer because of the following error (look up to see its traceback):
Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_tf_utils because of the following error (look up to see its traceback):
Your currently installed version of Keras is Keras 3, but this is not yet supported in Transformers. Please install the backwards-compatible tf-keras package with pip install tf-keras.

transformers 库建议安装兼容的 tf-keras 包来解决这个兼容性问题。你可以通过以下命令安装:

pip install tf-keras

现在问题应该得到了解决。

加载数据集

我们将使用预先提供的数据集,包括带标签的偏好数据和测试提示数据。

这个数据集来自于生成式人工智能导论的HW6,处理的问题是:是否应该将动漫真人化?两个回答分别对应支持和不支持(由GPT生成),在后面的代码中你将选择支持的占比。

git clone https://github.com/Baiiiiiiiiii/GenAI_hw6_dataset.git
with open("./GenAI_hw6_dataset/labelled_data.json", 'r') as jsonfile:full_data = json.load(jsonfile)with open("./GenAI_hw6_dataset/test_prompt.json", 'r') as jsonfile:test_data = json.load(jsonfile)

直观理解数据集:

full_data

image-20240919114655048

使用 HFD 下载模型

我们这里使用多线程的方法进行快速下载。

如果直接运行以下命令报错,根据 a. 使用 HFD 加快 Hugging Face 模型和数据集的下载 进行前置安装。

当然,你也可以取消我注释的部分,使用官方的命令进行安装,但是会很慢。

安装工具

sudo apt-get update
sudo apt-get install git wget curl aria2 git-lfs
git lfs install

下载 hfd 并修改权限

wget https://hf-mirror.com/hfd/hfd.sh
chmod a+x hfd.sh

多线程下载模型

export HF_ENDPOINT=https://hf-mirror.com
./hfd.sh 'MediaTek-Research/Breeze-7B-Instruct-v0_1' --tool aria2c -x 16

下载

加载模型

将使用MediaTek-Research/Breeze-7B-Instruct-v0_1模型进行微调。

model = AutoModelForCausalLM.from_pretrained('MediaTek-Research/Breeze-7B-Instruct-v0_1',device_map='auto',trust_remote_code=True,quantization_config=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.bfloat16,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type='nf4')
)

这里,我们采用了4位量化(4-bit quantization)来减少模型的内存占用,加快推理速度。

查看未经过微调的模型原始输出

在进行微调之前,我们首先查看一下原始模型的输出效果。首先,加载分词器:

tokenizer = AutoTokenizer.from_pretrained('MediaTek-Research/Breeze-7B-Instruct-v0_1')
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

定义一个数据处理函数,将数据格式化为模型可以接受的输入,我们这里的 prompt 延续原来的繁体(因为Breeze-7B-Instruct-v0_1更多使用繁体中文进行训练,你并不需要修改它):

def data_formulate(data):messages = [{"role": "system", "content": '回覆請少於20字'},{"role": "user", "content": data['prompt']},]prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)return prompt

接下来,生成原始模型的响应:

original_model_response = []
for data in tqdm(test_data):id = data['id']print(f'Question {id}:\n'+data['prompt'])inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')generation_config=GenerationConfig(do_sample=False,max_new_tokens = 200,pad_token_id = tokenizer.pad_token_id)output = model.generate(**inputs, generation_config=generation_config)output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]original_model_response.append(output)print('Response from original model:\n'+output+'\n')

这段代码将遍历测试数据集,生成并打印每个问题的原始模型响应。

image-20240919113918391

设置参数

你只需要修改这个模块,不需要改变其他的,除非你真的知道自己在做什么。

support_ratio 将反映你的偏好:

  • 0 表示完全不支持(反对)真人化
  • 1 表示完全支持真人化
  • 0.1 表示 10% 支持真人化
num_epoch = 1
data_size = 50
support_ratio = 0.1

准备训练数据

这里,我们将数据集分为支持(support)和反对(oppose)两部分,构建一个包含偏好对的训练数据集(是的,这里就是 DPO)。

# 选择部分数据用于训练
training_data = full_data[:data_size]# 定义 support 数据集的大小
support_data_size = int(data_size * support_ratio)# 为训练数据集准备数据
prompt_list = [data_formulate(data) for data in training_data]
chosen_list = [data['support'] for data in training_data[:support_data_size]] + [data['oppose'] for data in training_data[support_data_size:]]
rejected_list = [data['oppose'] for data in training_data[:support_data_size]] + [data['support'] for data in training_data[support_data_size:]]
position_list = ['support' for _ in range(support_data_size)] + ['oppose' for _ in range(data_size - support_data_size)]# 创建训练数据集
train_dataset = Dataset.from_dict({'prompt': prompt_list, 'position': position_list, 'chosen': chosen_list, 'rejected': rejected_list})
pd.DataFrame(train_dataset).rename(columns={"chosen": "preferred", "rejected": "non-preferred"})

总共有 50 笔训练数据,当 support 设置为 0.1 时,前 50*0.1=5 笔训练资料的偏好将倾向于支持真人化,后 50-4=45 笔资料反对真人化。

image-20240919114949791

训练

现在,我们进入训练阶段。首先,设置训练参数:

training_args = TrainingArguments(output_dir='./',per_device_train_batch_size=1,num_train_epochs=num_epoch,gradient_accumulation_steps=8,gradient_checkpointing=False,learning_rate=2e-4,optim="paged_adamw_8bit",logging_steps = 1,warmup_ratio = 0.1,report_to = 'none'
)

接下来,配置PEFT(Parameter-Efficient Fine-Tuning):

peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",
)

然后,初始化DPO训练器:

dpo_trainer = DPOTrainer(model,args=training_args,beta=0.1,train_dataset=train_dataset,tokenizer=tokenizer,peft_config=peft_config,
)

开始训练:

dpo_trainer.train()

image-20240919115410184

查看微调后的模型输出

训练完成后,我们需要查看微调后的模型效果。以下是生成训练后模型响应的代码:

trained_model_response = []
for data in tqdm(test_data):id = data['id']print(f'Question {id}:\n'+data['prompt'])inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')generation_config=GenerationConfig(do_sample=False,max_new_tokens = 200,pad_token_id = tokenizer.pad_token_id)output = model.generate(**inputs, generation_config=generation_config)output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]trained_model_response.append(output)print('Response from trained model:\n'+output+'\n')

这段代码与之前生成原始模型响应的代码类似,但这次生成的是经过微调后的模型响应:

image-20240919115643310

观察输出结果

最后,我们对比微调前后的模型响应,观察DPO方法带来的效果提升:

model_response = []
print(f'num_epoch: {num_epoch}\ndata_size: {data_size}\nsupport_ratio: {support_ratio}')
print()
for data in test_data:id = data['id']ref_output = original_model_response[id-1]output = trained_model_response[id-1]print(f'Question {id}:\n'+data['prompt'])print('Response from original model:\n'+ref_output)print('Response from trained model:\n'+output)print()model_response.append({'id':data['id'], 'prompt':data['prompt'], 'response_from_original_model':ref_output, 'response_from_trained_model':output})

image-20240919115708299

拓展

在使用 GPT 的时候你应该也见到过其同时生成两个回答让我们选择更倾向于哪个,这个和 Google 验证码有着异曲同工之妙。

进一步

12. Inseq 特征归因:可视化解释 LLM 的输出
李宏毅2024生成式人工智能导论 中文镜像版指导与作业

推荐阅读

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146168.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++——map和set的使用以及map系列

目录 map和set的使用 1. 序列式容器和关联式容器 2. set系列的使⽤ 2.1 set和multiset参考⽂档 2.2 set类的介绍 2.3 set的构造和迭代器 2.4 set的增删查 set的增删查关注以下⼏个接⼝即可: 2.6 find和erase使⽤样例: lower_bound(); upper_bo…

Python 从入门到实战23(属性property)

我们的目标是:通过这一套资料学习下来,通过熟练掌握python基础,然后结合经典实例、实践相结合,使我们完全掌握python,并做到独立完成项目开发的能力。 上篇文章我们讨论了类的定义、使用方法的相关知识。今天我们将学…

uboot:源码分析-启动第一阶段-start.S解析

start.S引入 进入start.S文件中,发现57行中就是_start标号的定义处 SourceInsight中添加行号 在SI中,如果我们知道我们要找的文件的名字,但是我们又不知道他在哪个目录下,我们要怎样找到并打开这个文件?方法是在SI中先…

教你快速制作一本3D翻页电子杂志

​在制作3D翻页电子杂志之前,我们需要了解一些基本概念。3D翻页电子杂志主要通过翻页效果来展示内容,读者可以通过手指滑动或点击鼠标来进行翻页。此外,它还支持图片、文字、视频等多种媒体形式的展示,为读者带来全方位的阅读体验…

KTH5774 —— 3D 摇杆/操纵杆霍尔位置传感器芯片

KTH5774 是一款摇杆、操纵杆专用的 3D 霍尔磁感 应芯片,主要面向对线性度和可靠性要求严格的应用 场景。 KTH5774 基于 3D 霍尔技术,内部分别集成了 X 轴、 Y 轴和 Z 轴三个独立的霍尔元件,能够通过测量和 处理磁通密度矢量的三个空间分量…

决策树算法中篇

手动计算实现决策树分类 数据整合 X[真实用户] y X 计算未划分信息熵 s X[真实用户] p s.value_counts()/s.size (p * np.log2(1/p)).sum() 按照日志密度进行划分 x X[日志密度].unique() x.sort() # 如何划分呢,分成两部分 for i in range(len(x) - 1):sp…

Checkstyle 使用总结

1. 使用 GitHub 地址:checkstyle/checkstyle: Checkstyle is a development tool to help programmers write Java code that adheres to a coding standard. 官网文档地址:checkstyle – Checkstyle 10.17.0 1.1 IDEA 插件 在 IDEA 搜索插件 CheckS…

DOS(Disk Operating System,磁盘操作系统)常用指令

目录 背景: 早期探索: DOS之父: 发展历程: 常用指令: 进入命令: 操作1.进入和回退: 操作2.增、删: 操作3.其它: 总结: 背景: 早期探索: DOS(Disk Operating System,磁盘操作系统)在…

基于云的补丁管理

什么是云补丁 云补丁或基于云的补丁管理是指扫描和检测缺失补丁、测试补丁并将它们部署到所需系统的过程,所有这些都通过基于云的控制台或软件完成。虽然补丁管理工作流程通常保持不变,但基于云的补丁管理的主要区别在于,整个过程仅通过基于…

数据跨境流通发展现状浅析

文章目录 前言一、数据跨境流通的场景二、数据跨境流通国内发展现状三、数据跨境流通国外发展现状1、国外的数据跨境政策类型:(1)美国以数据自由流动为核心(2)欧盟将人权保护作为首要考虑(3)俄罗…

2.1 SQL语言及如何创建数据表

一、什么是SQL语言 SQL语言全称叫做结构化查询语言,它是一种计算机语言,但是跟其他编程语言来比较还是有很大区别的。比如说HTML,CSS,Java script,这三种计算机语言是用在网页设计上面的。那么swift语言是用来开发IOS…

反转字符串中的单词--力扣151

反转字符串中的单词 题目思路代码 题目 思路 题目的难点在于首先要清除多余的空格,并且单词之间要留一个空格,首单词前和末尾单词后不能有多余空格。我们使用双指针去除所有的空格,然后在处理完一个单词后手动加一个单词。具体思路是当快指针…

k8s快速搭建+prometheus部署及使用(纯干货!!!)

目录 环境准备 1.所有主机安装docker 2.部署harbor 3.部署k8s 集群初始化 安装网络插件(此时选择的是flannel网络插件 后面也有calico网络插件的安装方法) 节点扩容 4.calico网络插件的部署(如果安装了flannel插件需要先删除&#xf…

web前端-HTML常用标签-综合案例

如图&#xff1a; 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document&…

LlamaIndex 中的 NodeParser

LlamaIndex 中 Document 会被转成 Node,Node 中的文字会进行 Embedding,最终保留向量数据做后续的搜索处理。这里的关键步骤是 Document 转为 Node 的策略,LlamaIndex 内置了多个 Document Reader 和 Node Parser,每个 NodeParser 都有自己的策略,需在初始化时进行设置。 …

基于springboot+vue超市管理系统

基于springbootvue超市管理系统 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本无人超市管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助使用者在…

STM32如何修改外部晶振频率和主频

对于STM32F10x系列的单片机&#xff0c;除了STM32F10x_CL单片机&#xff0c;其它的单片机一般外部晶振HSE的时钟频率都默认是8MHz。如果我们使用的外部晶振为12Mhz&#xff0c;那么可以把上图绿色标记改为:12000000 72MHz的主频8MHz的外部晶振HSE*倍频系数9。当然如果像上面把外…

四款好用的电脑录屏工具推荐!!

在科技日益发展的今天&#xff0c;屏幕录制已成为我们工作、学习和娱乐中不可或缺的一部分&#xff1b;无论是制作教程、记录游戏过程还是分享精彩瞬间&#xff0c;一个好的录屏工具都是不可或缺的&#xff1b;今天&#xff0c;我就为大家推荐四款实用又好用的电脑录屏工具&…

矿用立式负压自动排渣放水器感恩遇见

做良心产品一直是我们的初心好产品加上好服务&#xff0c;让您满意是我们一直的追求只凭低价去换取销量的话&#xff0c;就会想方设法降低成苯质量难有保障 矿用立式负压自动排渣放水器感恩遇见 概述 负压自动排渣放水器的型号为YCFP&#xff0c;YC指品牌永成&#xff0c;FP指…

mac os x 找不到钥匙串访问

昨天手贱更新了最新的mac系统&#xff0c;结果在实用工具中找不到钥匙串访问APP了。。。 最新mac系统为 15.0 (24A335) 真是醉了。。。 那就得想办法把他给呼出来&#xff0c;在开发者中心下载了一个.cer文件&#xff0c;然后双击打开&#xff0c;此时钥匙串打开了&#xff…