大模型微调,使用QLoRA和自定义数据集微调大模型(下)

4.8 数据预处理

在微调模型之前,我们不能直接使用原始数据集,需要将数据集中的提示转换成模型能够理解的格式。

为了使数据集适配微调流程,这里编写辅助函数来格式化输入数据集。具体来说,就是将对话摘要(即提示-响应对)转换成大型语言模型(LLM)能够识别的明确指令。

def create_prompt_formats(sample):"""格式化样本的各种字段('instruction', 'output')然后使用两个换行符将它们连接起来:param sample: 样本字典"""INTRO_BLURB = "以下是描述任务的指令。写一个适当完成请求的响应。"INSTRUCTION_KEY = "### 指令:总结以下对话。"RESPONSE_KEY = "### 输出:"END_KEY = "### 结束"blurb = f"\n{INTRO_BLURB}"instruction = f"{INSTRUCTION_KEY}"input_context = f"{sample['dialogue']}" if sample["dialogue"] else Noneresponse = f"{RESPONSE_KEY}\n{sample['summary']}"end = f"{END_KEY}"parts = [part for part in [blurb, instruction, input_context, response, end] if part]formatted_prompt = "\n\n".join(parts)sample["text"] = formatted_promptreturn sample

上述函数负责将输入数据转换成提示格式。

接下来,我们会用模型的分词器对这些提示进行处理,将其转换成标记化的形式。我们的目标是生成长度统一的输入序列,这样做有助于微调语言模型,因为它能提升处理效率并降低计算成本。同时,我们还得确保这些序列的长度不超过模型允许的最大标记数。

from functools import partial# 来源 https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def get_max_length(model):conf = model.configmax_length = Nonefor length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:max_length = getattr(model.config, length_setting, None)if max_length:print(f"找到最大长度:{max_length}")breakif not max_length:max_length = 1024print(f"使用默认最大长度:{max_length}")return max_lengthdef preprocess_batch(batch, tokenizer, max_length):"""分批标记化"""return tokenizer(batch["text"],max_length=max_length,truncation=True,)# 来源 https://github.com/databrickslabs/dolly/blob/master/training/trainer.py
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed, dataset):"""格式化并标记化,使其准备好进行训练:param tokenizer (AutoTokenizer): 模型分词器:param max_length (int): 分词器发出的最大标记数"""# 为每个样本添加提示print("预处理数据集...")dataset = dataset.map(create_prompt_formats)#, batched=True)# 对每个批次的数据集应用预处理并移除 'instruction', 'context', 'response', 'category' 字段_preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)dataset = dataset.map(_preprocessing_function,batched=True,remove_columns=['id', 'topic', 'dialogue', 'summary'],)# 过滤掉输入_ids超过最大长度的样本dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)# 随机打乱数据集dataset = dataset.shuffle(seed=seed)return dataset

通过这些函数的处理,我们的数据集已经准备好进行微调了。

## 数据集预处理
max_length = get_max_length(original_model)
print(max_length)train_dataset = preprocess_dataset(tokenizer, max_length,seed, dataset['train'])
eval_dataset = preprocess_dataset(tokenizer, max_length,seed, dataset['validation'])

4.9 准备模型进行QLoRA训练

使用PEFT库中的prepare_model_for_kbit_training方法来准备模型。

通过这个方法,我们对原始模型original_model进行初始化,设置好必要的配置,以便进行QLoRA训练。

4.10 设置PEFT进行微调

接下来,我们配置LoRA参数,以便对基础模型进行微调。

from peft import LoraConfig, get_peft_modelconfig = LoraConfig(r=32,  # 定义适配器的秩lora_alpha=32,  # 学习权重的缩放因子target_modules=['q_proj', 'k_proj', 'v_proj', 'dense'],bias="none",lora_dropout=0.05,  # 防止过拟合task_type="CAUSAL_LM",  # 任务类型
)# 启用梯度检查点减少内存消耗
original_model.gradient_checkpointing_enable()# 获取配置好的PEFT模型
peft_model = get_peft_model(original_model, config)

这里的r(秩)参数控制适配器的复杂度,影响模型的表达能力和计算成本。lora_alpha参数用于调整学习权重,影响LoRA激活的强度。配置完成后,我们可以通过辅助函数查看模型的可训练参数数量。

print(print_number_of_trainable_model_parameters(peft_model))

图片

4.11 训练PEFT适配器

来设置训练参数,并初始化训练器。

import transformers
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling# 设置训练参数
output_dir = f'./peft-dialogue-summary-training-{str(int(time.time()))}'
peft_training_args = TrainingArguments(output_dir=output_dir,warmup_steps=1,per_device_train_batch_size=1,gradient_accumulation_steps=4,max_steps=1000,learning_rate=2e-4,optim="paged_adamw_8bit",logging_steps=25,logging_dir="./logs",save_strategy="steps",save_steps=25,evaluation_strategy="steps",eval_steps=25,do_eval=True,gradient_checkpointing=True,report_to="none",overwrite_output_dir=True,group_by_length=True,
)# 禁用缓存,准备训练器
peft_model.config.use_cache = False
peft_trainer = Trainer(model=peft_model,args=peft_training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

我们计划进行1000步训练,这个数字对于我们的特定数据集应该是合适的。不过,可能需要根据实际情况调整这个数值。现在,可以开始训练了,训练时间会根据超参数的不同而有所差异。

peft_trainer.train()

训练完成后,准备一个用于推理的模型。给原始的Phi-2模型添加一个适配器,并设置为不可训练状态,因为我们只打算用它来做推理。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizerbase_model_id = "microsoft/phi-2"
base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True)
eval_tokenizer.pad_token = eval_tokenizer.eos_tokenfrom peft import PeftModelft_model = PeftModel.from_pretrained(base_model,"/kaggle/working/peft-dialogue-summary-training-1705417060/checkpoint-1000",torch_dtype=torch.float16,is_trainable=False
)

微调是一个需要反复试验的过程。我们可能需要根据验证集和测试集的表现来调整模型结构、超参数或训练数据,以提升模型性能。

4.12 定性评估模型(人工评估)

使用PEFT模型对同一个输入进行推理,以评估其性能。

from transformers import set_seed
set_seed(seed)# 选择测试集中的一个对话样本
index = 5
dialogue = dataset['test'][index]['dialogue']
summary = dataset['test'][index]['summary']# 构建推理提示
prompt = f"Instruct: Summarize the following conversation.\n{dialogue}\nOutput:\n"# 使用PEFT模型生成摘要
peft_model_res = gen(ft_model, prompt, 100)
peft_model_output = peft_model_res[0].split('Output:\n')[1]
prefix, _, _ = peft_model_output.partition('###')# 打印结果
dash_line = '-' * 100
print(dash_line)
print(f'输入提示:\n{prompt}')
print(dash_line)
print(f'人工摘要:\n{summary}\n')
print(dash_line)
print(f'PEFT模型输出:\n{prefix}')

这段代码将展示PEFT模型相对于人类基线摘要的表现。通过比较模型输出和人类生成的摘要,我们可以定性地评估PEFT模型的性能。

图片

PEFT 模型输出

4.13 定量评估模型(使用ROUGE指标)

ROUGE(Recall-Oriented Understudy for Gisting Evaluation)是一种评估自动摘要和机器翻译软件的工具,它通过比较机器生成的摘要与人工生成的参考摘要来衡量效果。虽然不完美,但ROUGE指标能有效反映微调后摘要质量的整体提升。

我们通过ROUGE指标来定量评估模型生成的摘要。以下是评估过程:

from transformers import AutoModelForCausalLM
import pandas as pd
import evaluate
import numpy as np# 加载模型和数据
original_model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map='auto')
dialogues = dataset['test'][0:10]['dialogue']
human_baseline_summaries = dataset['test'][0:10]['summary']# 初始化摘要列表
original_model_summaries = []
peft_model_summaries = []# 生成摘要并评估
for idx, dialogue in enumerate(dialogues):prompt = f"Instruct: Summarize the following conversation.\n{dialogue}\nOutput:\n"original_model_res = gen(original_model, prompt, 100)original_model_text_output = original_model_res[0].split('Output:\n')[1]peft_model_res = gen(ft_model, prompt, 100)peft_model_output = peft_model_res[0].split('Output:\n')[1]peft_model_text_output, _, _ = peft_model_output.partition('###')original_model_summaries.append(original_model_text_output)peft_model_summaries.append(peft_model_text_output)# 将结果存入DataFrame
df = pd.DataFrame(list(zip(human_baseline_summaries, original_model_summaries, peft_model_summaries)),columns=['human_baseline_summaries', 'original_model_summaries', 'peft_model_summaries'])# 计算ROUGE指标
rouge = evaluate.load('rouge')
original_model_results = rouge.compute(predictions=original_model_summaries, references=human_baseline_summaries)
peft_model_results = rouge.compute(predictions=peft_model_summaries, references=human_baseline_summaries)# 打印结果
print('原始模型ROUGE指标:')
print(original_model_results)
print('PEFT模型ROUGE指标:')
print(peft_model_results)# 计算提升百分比
improvement = (np.array(list(peft_model_results.values())) - np.array(list(original_model_results.values())))
for key, value in zip(peft_model_results.keys(), improvement):print(f'{key}提升: {value*100:.2f}%')

通过ROUGE指标,我们可以看到PEFT模型相较于原始模型在摘要质量上的显著提升。

图片

5 结语

微调大型语言模型(LLM)已成为寻求优化运营流程的企业必不可少的步骤。虽然LLM的初始训练赋予了广泛的语言理解能力,但微调过程将这些模型细化为能够处理特定主题并提供更准确结果的专用工具。为不同任务、行业或数据集定制LLM扩展了这些模型的能力,确保了它们在不断变化的数字环境中的相关性和价值。展望未来,LLM的持续探索和创新,加上精细的微调方法,有望推进更智能、更高效、更具情境意识的人工智能系统的发展。

如何学习AI大模型?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/7869.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【NOIP普及组】质因数分解

【NOIP普及组】质因数分解 C语言代码C代码Java代码Python代码 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 已知正整数 n 是两个不同的质数的乘积&#xff0c;试求出较大的那个质数。 输入 输入只有一行&#xff0c;包含一个正整数…

js--高阶函数之参数归一化

一、前言 参数归一化&#xff1a;是我们软件开发里一个非常重要且实用的技巧&#xff0c;用的好极大简化代码同时提升代码的可阅读性和可维护性。以下我用日期格式化为例&#xff0c;演示一下参数归一化的技巧。 二、日期格式化实例 /*** 辅助格式化函数* param {string|functi…

均值、期望、方差、标准差与协方差:基础概念解析

均值、期望、方差、标准差与协方差&#xff1a;基础概念解析 在统计学和数据分析中&#xff0c;均值、期望、方差、标准差和协方差是描述数据分布和关系的基本工具。理解这些概念有助于我们更好地分析和处理数据。本文将详细讲解这些概念的定义、计算方法及其在实际应用中的意…

shell基础

一、理解bash基础 默认的Linux shell——Bash&#xff08;Bourne Again SHell&#xff09;可以通过命令控制系统&#xff0c;执行文件操作&#xff0c;或者启动应用程序。它可以在命令行上交互式使用&#xff0c;或者你可以创建一个包含多个shell命令的文件&#xff0c;并像启…

js树状结构,自叶到根统计各级数量

$($(".tree-item").get().reverse()).each(function () {let self $(this).find("span").text()let prev $(this).parent(".two").prevAll(".tree-item").find("span").text()self self ? self : 0prev prev ? prev :…

学习threejs,使用JSON格式保存和加载整个场景

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE toJSON()方法 二、&a…

论文1—《基于卷积神经网络的手术机器人控制系统设计》文献阅读分析报告

论文报告&#xff1a;基于卷积神经网络的手术机器人控制系统设计 摘要 本研究针对传统手术机器人控制系统精准度不足的问题&#xff0c;提出了一种基于卷积神经网络的手术机器人控制系统设计。研究设计了控制系统的总体结构&#xff0c;并选用PCI插槽上直接内插CAN适配卡作为上…

SLF4J: Failed to load class “org.slf4j.impl.StaticLoggerBinder“

SLF4J常见问题 导入依赖&#xff1a; <dependency><groupId>log4j</groupId><artifactId>log4j</artifactId><version>1.2.17</version> </dependency> <dependency><groupId>org.slf4j</groupId><arti…

资产管理系统:SpringBoot技术驱动

4系统概要设计 4.1概述 系统设计原则 以技术先进、系统实用、结构合理、产品主流、低成本、低维护量作为基本建设原则&#xff0c;规划系统的整体构架. 先进性&#xff1a; 在产品设计上&#xff0c;整个系统软硬件设备的设计符合高新技术的潮流&#xff0c;媒体数字化、压缩、…

YOLO可视化界面,目标检测前端页面。

使用PySide6/QT实现YOLOv5/v8可视化GUI页面 在人工智能和计算机视觉领域&#xff0c;YOLO&#xff08;You Only Look Once&#xff09;是一种广泛使用的实时目标检测算法。为了直观地展示YOLO算法的检测效果&#xff0c;我们可以使用Python中的PySide6库来创建一个简单的GUI应…

使用vuex动态设置全局字号

1.安装vuex npm install vuexnext --save 2.编写字号设置样式 // 定义字号变量 :root {--font-size: 18px;--font-size-step1: 16px;--font-size-step2: 14px;--font-size-step3: 12px; } // 定义样式&#xff08;全局样式文件&#xff09; body, page {font-size: var(--fo…

编程爱好者的福音:实用技巧与教程

引言 你是否曾经因为代码无法正常运行而感到挫败&#xff1f;或者在面对一行行复杂的代码时&#xff0c;不知道从何下手&#xff1f;编程&#xff0c;这项充满挑战与创造力的技能&#xff0c;往往让人既爱又恨。无论你是刚刚入门的初学者&#xff0c;还是已经具备一定经验的开发…

了解bootstrap改造asp.net core MVC的样式模板

我们都知道&#xff0c;在使用默认的asp.net core MVC模板建立项目的时候&#xff0c;里面的样式是已经事先被写好了的。一般来说都在css目录下的site.css和bootstrap.css及下面的bootstrap.min.css中。我们打开bootstrap这些样式文件&#xff0c;里面有大量的样式类的定义&…

通过使用 FFmpeg 提取某站视频 MV 中的音频为 MP3

无论是为了个人收藏、制作播客还是作为背景音乐&#xff0c;将视频中的音频提取出来都是一个非常实用的技能。本教程中简鹿办公将介绍两种方法来实现这一目标&#xff1a;一种是通过命令行工具 FFmpeg&#xff0c;另一种是使用图形界面工具 - 简鹿音频格式转换器。 使用 FFmpeg…

探秘国际数字影像产业园:数字化转型之路

数字化园区的概念正日益受到全球瞩目&#xff0c;这不仅是科技进步的必然产物&#xff0c;更是现代经济发展的迫切需求。对于国际数字影像产业园而言&#xff0c;打造数字化园区意味着通过尖端科技手段&#xff0c;全面提升园区的管理效率、服务质量及入驻企业和居民的生活体验…

外包干了2年,快要废了。。

先说一下自己的情况&#xff0c;普通本科&#xff0c;在外包干了2年多的功能测试&#xff0c;这几年因为大环境不好&#xff0c;我整个人心惊胆战的&#xff0c;怕自己卷铺盖走人了&#xff0c;我感觉自己不能够在这样蹉跎下去了&#xff0c;长时间呆在一个舒适的环境真的会让一…

5G的发展演进

5G发展的驱动力 什么是5G [远程会议&#xff0c;2020年7月10日] 在来自世界各地的政府主管部门、电信制造及运营企业、研究机构约200多名会议代表和专家们的共同见证下&#xff0c;ITU-R WP 5D#35e远程会议宣布3GPP 5G技术&#xff08;含NB-IoT&#xff09;满足IMT-2020 5G技…

【C++打怪之路Lv14】- “多态“篇

&#x1f308; 个人主页&#xff1a;白子寰 &#x1f525; 分类专栏&#xff1a;重生之我在学Linux&#xff0c;C打怪之路&#xff0c;python从入门到精通&#xff0c;数据结构&#xff0c;C语言&#xff0c;C语言题集&#x1f448; 希望得到您的订阅和支持~ &#x1f4a1; 坚持…

Github 2024-11-05 Python开源项目日报Top10

根据Github Trendings的统计,今日(2024-11-05统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目10HTML项目1TypeScript项目1系统设计指南 创建周期:2507 天开发语言:Python协议类型:OtherStar数量:241693 个Fork数量:42010 次…

如何从 Android 图库中恢复误删除的照片

如果您正在阅读这篇文章&#xff0c;那么您肯定意外地从 Android 设备中删除了照片。并且您正在寻找一种简单的方法来恢复 Android 图库中已删除的照片。 从图库恢复已删除的照片 随着技术的进步&#xff0c;现在使用单个设备&#xff08;即 Android 手机&#xff09;&#xf…