【Finetune】(二)、transformers之Prompt-Tuning微调

文章目录

  • 0、prompt-tuning基本原理
  • 1、实战
    • 1.1、导包
    • 1.2、加载数据
    • 1.3、数据预处理
    • 1.4、创建模型
    • 1.5、Prompt Tuning*
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、推理:加载预训练好的模型

0、prompt-tuning基本原理

 prompt-tuning的基本思想就是冻结主模型的全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示向量,即一个Embedding模块。其中,prompt又存在两种形式,一种是hard prompt,一种是soft prompt。

在这里插入图片描述

1、实战

1.1、导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")

1.3、数据预处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer

def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)

1.5、Prompt Tuning*

1.5.1、配置文件

#soft prompt# config = PromptTuningConfig(
#     task_type=TaskType.CAUSAL_LM,
#     num_virtual_tokens=10,
#     )
# config
#hard prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,prompt_tuning_init = PromptTuningInit.TEXT,prompt_tuning_init_text = '下面是一段机器人的对话:',num_virtual_tokens=len(tokenizer('下面是一段机器人的对话:')['input_ids']),tokenizer_name_or_path='../Model/bloom-389m-zh',)
config

1.5.2、创建模型

model= get_peft_model(model,config)
model

打印模型训练参数

model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=1
)

1.7、创建训练器

trainer = Trainer(args=args,model=model,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)

1.8、模型训练

trainer.train()

1.9、推理:加载预训练好的模型

from peft import PeftModel
peft_model =  PeftModel.from_pretrained(model=model,model_id='./chat_bot/checkpoint500/')
from transformers import pipelinepipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/144191.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【论文阅读】FedABC: Targeting Fair Competition in Personalized Federated Learning

论文链接(AAAI2023) 文章解决的问题主要是NO-IID问题。 文章的方法包括几个关键的技术和策略,具体如下: 二元分类框架: FedABC利用二元分类的训练策略来解决每个类别的个性化问题。这意味着对于每个类别都训练一个独立…

初识 C++ ( 1 )

引言:大家都说c是c的升级语言。我不懂这句话的含义后来看过解释才懂。 一、面向过程语言和面向对象语言 我们都知道C语言是面向过程语言,而C是面向对象语言,说C和C的区别,也就是在比较面向过程和面向对象的区别。 1.面向过程和面向…

JDBC 编程

目录 JDBC 是什么 JDBC 的工作原理 JDBC 的使用 引入驱动 使用 常用接口和类 Connection Statement ResultSet 使用总结 JDBC 是什么 JDBC(Java Database Connectivity):Java数据库连接,是一种用于执行 SQL 语句的Java…

‍♀️焦虑症患者的救赎之路:这5项运动让你重拾宁静与力量!

在这个快节奏、高压力的时代,焦虑症已成为许多人难以言说的秘密。它像一张无形的网,悄悄侵蚀着我们的心灵,让我们在日复一日的焦虑中挣扎。然而,你知道吗?运动,这一简单而强大的自然疗法,正是我…

强化信息安全:密码机密钥管理的策略与实践

强化信息安全:密码机密钥管理的策略与实践 随着信息技术的飞速发展,信息安全已成为企业和社会关注的焦点。密码机作为加密通信和数据保护的关键设备,其密钥管理直接关系到整个信息系统的安全性。本文旨在探讨密码机密钥管理的策略与实践&…

Java 实现桌面烟花秀

前言 今天,我们将展示如何使用 Java Swing 创建一个烟花效果,覆盖整个桌面。我们将重点讲解如何在桌面上展示烟花、如何实现发射和爆炸效果,以及如何将这些效果整合到一个完整的程序中。 效果展示 如上图所示,我们在桌面实现了&…

深入解析ThingsBoard与ThingsKit物联网平台的差异

VS 在物联网(IoT)领域,平台的选择对于企业来说至关重要。本文将深入探讨ThingsBoard社区版与ThingsKit企业版这两个物联网平台的差异,帮助读者更好地理解它们的特色和适用场景。 系统相同点 首先,ThingsBoard社区版和ThingsKit企业版都基于…

Flink1.18.1 Standalone模式集群搭建

Flink1.18.1 Standalone模式集群搭建 Flink1.18.1 Standalone模式集群搭建1. 环境准备1.1 Flink下载地址1.2 集群角色分配 2. Flink 集群安装步骤2.1 下载并解压 Flink2.2 解压安装包2.3 配置环境变量2.4 配置 SSH 免密登录 3. 配置 Flink 集群3.1 修改 flink-conf.yaml 配置文…

Day99 代码随想录打卡|动态规划篇--- 01背包问题

题目(卡玛网T46): 小明是一位科学家,他需要参加一场重要的国际科学大会,以展示自己的最新研究成果。他需要带一些研究材料,但是他的行李箱空间有限。这些研究材料包括实验设备、文献资料和实验样本等等&am…

【linux008】目录操作命令篇 - rmdir 命令

文章目录 1、基本用法2、常见选项3、举例4、注意事项 rmdir 是 Linux 系统中的一个命令,用于删除空目录。它只能删除 空目录,如果目录中存在文件或子目录,则无法删除。 1、基本用法 rmdir [选项] 目录名...2、常见选项 -p, --parents&…

Linux标准IO-系统调用详解

1.1 系统调用 系统调用(system call)其实是 Linux 内核提供给应用层的应用编程接口(API),是 Linux 应用层进入内核的入口。不止 Linux 系统,所有的操作系统都会向应用层提供系统调用,应用程序通…

在 Windows 上恢复已删除的 PDF 文件的最佳方法

如果您不小心删除了 PDF 文件或由于系统突然崩溃而无法再找到它们,本指南介绍了恢复已删除文件的最佳方法。 帖子中列出的方法简单、有效且可行。我们在列出它们之前对其进行了测试。 什么是 PDF,Adobe 将未保存的 PDF 存储在哪里? 自从 Ad…

无损转换:严选4个视频mkv转mp4格式的方法

视频的mkv格式是较为清晰的视频格式,但越清晰的视频格式所占的设备内存也就越大,从而也可能会出现视频传输失败、播放卡顿等的问题。对此,我们可以将视频转换为体积较小的格式来解决上述问题,如mkv转mp4。接下来,小编就…

实战讲稿:Spring Boot整合MyBatis

文章目录 实战讲稿:Spring Boot整合MyBatis课程目标课程内容1. 创建员工映射器接口1.1 创建子包1.2 创建接口 2. 测试员工映射器接口2.1 自动装配员工映射器2.2 测试按标识符查询员工方法2.3 测试查询全部员工方法2.4 测试插入员工方法2.5 测试更新员工方法2.6 测试…

『玉竹』基于Laravel 开发的博客、微博客系统和Android App

基于 Laravel 和 Filament 开发, 使用 Filament 开发管理后台,前端比较简洁。 博客大家都清楚是什么东西,微博客类似于微博之类的吧,有时候想要写的东西可能只有几句话,想要起个标题都不好起。 为了是微博客功能更好用&#xff0c…

通信工程学习:什么是ONT光网络终端

ONT:光网络终端 ONT(Optical Network Terminal,光网络终端)是光纤接入网络(FTTH)中的关键设备,用于将光纤信号转换为电信号或将电信号转换为光信号,以实现用户设备与光纤网络的连接。…

我的AI工具箱Tauri版-VideoIntroductionClipCut视频介绍混剪

本教程基于自研的AI工具箱Tauri版进行VideoIntroductionClipCut视频介绍混剪。 本项目为自研的AI工具箱Tauri版中的视频剪辑模块,专注于自动生成视频介绍片段。该模块名为 VideoIntroductionClipCut,用户可以通过该工具快速进行视频的混剪和介绍内容的生…

Android开发高频面试题之——Android篇

Android开发高频面试题之——Android篇 Android开发高频面试题之——Java基础篇 Android开发高频面试题之——Kotlin基础篇 Android开发高频面试题之——Android基础篇 1. Activity启动模式 standard 标准模式,每次都是新建Activity实例。singleTop 栈顶复用。如果要启动的A…

[数据集][目标检测]红外微小目标无人机直升机飞机飞鸟检测数据集VOC+YOLO格式7559张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):7559 标注数量(xml文件个数):7559 标注数量(txt文件个数):7559 标注…

Web转发(forward)与重定向(redirect)

请求转发forward -> xxServlet收到请求 -> 直接转发给yyServlet -> yyServlet返回给客户端 整个过程中,客户端发出一个请求,收到一个响应。 实现: 方式一:利用RequestDispather接口中的forward方法实现请求转发。 RequestDispathe…