基于 Qwen2-1.5B Lora 微调训练医疗问答任务

一、Qwen2 Lora 微调

Qwen是阿里巴巴集团Qwen团队研发的大语言模型和大型多模态模型系列。Qwen2Qwen1.5 的重大升级。无论是语言模型还是多模态模型,均在大规模多语言和多模态数据上进行预训练,并通过高质量数据进行后期微调以贴近人类偏好。Qwen具备自然语言理解、文本生成、视觉理解、音频理解、工具使用、角色扮演、作为AI Agent进行互动等多种能力。

Qwen2有以下特点:

  • 5种模型规模,包括0.5B、1.5B、7B、57B-A14B72B
  • 针对每种尺寸提供基础模型和指令微调模型,并确保指令微调模型按照人类偏好进行校准;
  • 基础模型和指令微调模型的多语言支持;
  • 所有模型均稳定支持32K长度上下文;Qwen2-7B-InstructQwen2-72B-Instruct可支持128K上下文(需额外配置)
  • 支持工具调用、RAG(检索增强文本生成)、角色扮演、AI Agent等;

更多详细的介绍可以参考官方文档:

https://qwen.readthedocs.io/zh-cn/latest/

下面实验所使用的核心依赖版本如下:

torch==1.13.1+cu116
peft==0.12.0
transformers==4.37.0
tensorboard==2.17.1

二、构建 Qwen2-1.5B Lora 模型

LoRA 微调技术的思想很简单,在原始 PLM (Pre-trained Language Model) 增加一个旁路,一般是在 transformer 层,做一个降维再升维的操作,模型的输入输出维度不变,来模拟 intrinsic rank,如下图的 AB。训练时冻结 PLM 的参数,只训练 AB ,,输出时将旁路输出与 PLM 的参数叠加,进而影响原始模型的效果。该方式,可以大大降低训练的参数量,而性能可以优于其它参数高效微调方法,甚至和全参数微调(Fine-Tuning)持平甚至超过。

对于 AB 参数的初始化,A 使用随机高斯分布,B 使用 0 矩阵,这样在最初时可以保证旁路为一个 0 矩阵,最开始时使用原始模型的能力。

在这里插入图片描述

在构建 Qwen2-1.5B Lora 结构模型前,先了解下现在 Qwen2-1.5B 的结构:

这里直接使用 PyTorch 的模型打印方式,主要看模型的组成:

from transformers import AutoModelForCausalLMmodel_path = "model/Qwen2-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
print(model)

输出结果:

Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 1536)(layers): ModuleList((0): Qwen2DecoderLayer((self_attn): Qwen2Attention((q_proj): Linear(in_features=1536, out_features=1536, bias=True)(k_proj): Linear(in_features=1536, out_features=256, bias=True)(v_proj): Linear(in_features=1536, out_features=256, bias=True)(o_proj): Linear(in_features=1536, out_features=1536, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=1536, out_features=8960, bias=False)(up_proj): Linear(in_features=1536, out_features=8960, bias=False)(down_proj): Linear(in_features=8960, out_features=1536, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()).. 省略中间结构.(27): Qwen2DecoderLayer((self_attn): Qwen2Attention((q_proj): Linear(in_features=1536, out_features=1536, bias=True)(k_proj): Linear(in_features=1536, out_features=256, bias=True)(v_proj): Linear(in_features=1536, out_features=256, bias=True)(o_proj): Linear(in_features=1536, out_features=1536, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=1536, out_features=8960, bias=False)(up_proj): Linear(in_features=1536, out_features=8960, bias=False)(down_proj): Linear(in_features=8960, out_features=1536, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()))(norm): Qwen2RMSNorm())(lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)

从上面的结构可以看出 Qwen2-1.5B 的结构其实并不复杂,由 27DecoderLayer 构成,每个 Decoder 主要的核心是 self_attentionmlp,因此可以尝试在 q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj 层添加 Lora 结构,下面使用 PEFT 库实现,这里 r 使用 8lora_alpha 使用 32

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskTypemodel_path = "model/Qwen2-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(model)

输出结果:

trainable params: 9,232,384 || all params: 1,552,946,688 || trainable%: 0.5945
PeftModelForCausalLM((base_model): LoraModel((model): Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(151936, 1536)(layers): ModuleList((0): Qwen2DecoderLayer((self_attn): Qwen2Attention((q_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=1536, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(k_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=256, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=256, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(v_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=256, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=256, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(o_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=1536, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=8960, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=8960, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(up_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=8960, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=8960, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(down_proj): lora.Linear((base_layer): Linear(in_features=8960, out_features=1536, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=8960, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()).. 省略中间结构.  (27): Qwen2DecoderLayer((self_attn): Qwen2Attention((q_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=1536, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(k_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=256, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=256, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(v_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=256, bias=True)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=256, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(o_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=1536, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=8960, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=8960, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(up_proj): lora.Linear((base_layer): Linear(in_features=1536, out_features=8960, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=1536, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=8960, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(down_proj): lora.Linear((base_layer): Linear(in_features=8960, out_features=1536, bias=False)(lora_dropout): ModuleDict((default): Dropout(p=0.1, inplace=False))(lora_A): ModuleDict((default): Linear(in_features=8960, out_features=8, bias=False))(lora_B): ModuleDict((default): Linear(in_features=8, out_features=1536, bias=False))(lora_embedding_A): ParameterDict()(lora_embedding_B): ParameterDict()(lora_magnitude_vector): ModuleDict())(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()))(norm): Qwen2RMSNorm())(lm_head): Linear(in_features=1536, out_features=151936, bias=False)))
)

从结果可以看出,Lora 之后在每一层都增加了一个 lora_Alora_B 结构来实现降维升维的作用。

三、准备训练数据集

数据集采用 GitHub 上的 Chinese-medical-dialogue-data 中文医疗对话数据集。

GitHub 地址如下:

https://github.com/Toyhom/Chinese-medical-dialogue-data

数据分了 6 个科目类型:

在这里插入图片描述

数据格式如下所示:

在这里插入图片描述

其中 ask 为病症的问题描述,answer 为病症的回答。

该数据集在本专栏的前面文章中,已经被使用在 ChatGLM2ChatYuan-large 模型上做过微调实验,感兴趣的小伙伴可以参考一下:

ChatGLM2-6B Lora 微调训练医疗问答任务

基于第二代 ChatGLM2-6B P-Tuning v2 微调训练医疗问答任务

ChatYuan-large-v2 微调训练 医疗问答 任务

由于整体数据比较多,这里为了演示效果,选取 内科、肿瘤科、儿科、外科 四个科目的数据进行实验,并且每个科目取前 10000 条数据进行训练、2000 条数据进行验证。

首先将数据集转为 json 格式方便后续读取:

import json
import pandas as pddata_path = ["./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv","./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000def main():train_f = open(train_json_path, "a", encoding='utf-8')val_f = open(val_json_path, "a", encoding='utf-8')for path in data_path:data = pd.read_csv(path, encoding='ANSI')train_count = 0val_count = 0for index, row in data.iterrows():question = row["ask"]answer = row["answer"]line = {"question": question,"answer": answer}line = json.dumps(line, ensure_ascii=False)if train_count < train_size:train_f.write(line + "\n")train_count = train_count + 1elif val_count < val_size:val_f.write(line + "\n")val_count = val_count + 1else:breakprint("数据处理完毕!")train_f.close()val_f.close()if __name__ == '__main__':main()

处理之后可以看到两个生成的文件:

在这里插入图片描述

四、微调训练

解析数据,构建 Dataset 数据集

qa_dataset.py:

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as npclass QADataset(Dataset):def __init__(self, data_path, tokenizer, max_source_length, max_target_length) -> None:super().__init__()self.tokenizer = tokenizerself.max_source_length = max_source_lengthself.max_target_length = max_target_lengthself.max_seq_length = self.max_source_length + self.max_target_lengthself.data = []if data_path:with open(data_path, "r", encoding='utf-8') as f:for line in f:if not line or line == "":continuejson_line = json.loads(line)question = json_line["question"]answer = json_line["answer"]self.data.append({"question": question,"answer": answer})print("data load , size:", len(self.data))def preprocess(self, question, answer):messages = [{"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},{"role": "user", "content": question}]prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)instruction = self.tokenizer(prompt, add_special_tokens=False, max_length=self.max_source_length)response = self.tokenizer(answer, add_special_tokens=False, max_length=self.max_target_length)input_ids = instruction["input_ids"] + response["input_ids"] + [self.tokenizer.pad_token_id]attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [self.tokenizer.pad_token_id]if len(input_ids) > self.max_seq_length:input_ids = input_ids[:self.max_seq_length]attention_mask = attention_mask[:self.max_seq_length]labels = labels[:self.max_seq_length]return input_ids, attention_mask, labelsdef __getitem__(self, index):item_data = self.data[index]input_ids, attention_mask, labels = self.preprocess(**item_data)return {"input_ids": torch.LongTensor(np.array(input_ids)),"attention_mask": torch.LongTensor(np.array(attention_mask)),"labels": torch.LongTensor(np.array(labels))}def __len__(self):return len(self.data)

训练:

# -*- coding: utf-8 -*-
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import pandas as pd
from qa_dataset import QADataset
from tqdm import tqdm
import os, time, sysdef train_model(model, train_loader, val_loader, optimizer, gradient_accumulation_steps,device, num_epochs, model_output_dir, writer):batch_step = 0for epoch in range(num_epochs):time1 = time.time()model.train()for index, data in enumerate(tqdm(train_loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):input_ids = data['input_ids'].to(device, dtype=torch.long)attention_mask = data['attention_mask'].to(device, dtype=torch.long)labels = data['labels'].to(device, dtype=torch.long)# 前向传播outputs = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels,)loss = outputs.loss# 反向传播,计算当前梯度loss.backward()# 梯度累积步数if (index % gradient_accumulation_steps == 0 and index != 0) or index == len(train_loader) - 1:# 更新网络参数optimizer.step()# 清空过往梯度optimizer.zero_grad()writer.add_scalar('Loss/train', loss, batch_step)batch_step += 1# 100轮打印一次 lossif index % 100 == 0 or index == len(train_loader) - 1:time2 = time.time()tqdm.write(f"{index}, epoch: {epoch} -loss: {str(loss)} ; each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")# 验证model.eval()val_loss = validate_model(model, val_loader, device)writer.add_scalar('Loss/val', val_loss, epoch)print(f"val loss: {val_loss} , epoch: {epoch}")print("Save Model To ", model_output_dir)model.save_pretrained(model_output_dir)def validate_model(model, device, val_loader):running_loss = 0.0with torch.no_grad():for _, data in enumerate(tqdm(val_loader, file=sys.stdout, desc="Validation Data")):input_ids = data['input_ids'].to(device, dtype=torch.long)attention_mask = data['attention_mask'].to(device, dtype=torch.long)labels = data['labels'].to(device, dtype=torch.long)outputs = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels,)loss = outputs.lossrunning_loss += loss.item()return running_loss / len(val_loader)def main():# 基础模型位置model_name = "model/Qwen2-1.5B-Instruct"# 训练集train_json_path = "./data/train.json"# 验证集val_json_path = "./data/val.json"max_source_length = 128max_target_length = 256epochs = 10batch_size = 1lr = 1e-4gradient_accumulation_steps = 16lora_rank = 8lora_alpha = 32model_output_dir = "output"logs_dir = "logs"# 设备device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 加载分词器和模型tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)# setup peftpeft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False,r=lora_rank,lora_alpha=lora_alpha,lora_dropout=0.1)model = get_peft_model(model, peft_config)model.is_parallelizable = Truemodel.model_parallel = Truemodel.print_trainable_parameters()print("Start Load Train Data...")train_params = {"batch_size": batch_size,"shuffle": True,"num_workers": 0,}training_set = QADataset(train_json_path, tokenizer, max_source_length, max_target_length)training_loader = DataLoader(training_set, **train_params)print("Start Load Validation Data...")val_params = {"batch_size": batch_size,"shuffle": False,"num_workers": 0,}val_set = QADataset(val_json_path, tokenizer, max_source_length, max_target_length)val_loader = DataLoader(val_set, **val_params)# 日志记录writer = SummaryWriter(logs_dir)# 优化器optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)model = model.to(device)# 开始训练print("Start Training...")train_model(model=model,train_loader=training_loader,val_loader=val_loader,optimizer=optimizer,gradient_accumulation_steps=gradient_accumulation_steps,device=device,num_epochs=epochs,model_output_dir=model_output_dir,writer=writer)writer.close()if __name__ == '__main__':main()

训练过程:

在这里插入图片描述

训练结束后,可以在 output 中看到 lora 模型:

在这里插入图片描述

五、模型测试

# -*- coding: utf-8 -*-
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torchmodel_path = "model/Qwen2-1.5B-Instruct"
lora_dir = "output"device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PeftModel.from_pretrained(model, lora_dir)
model.to(device)prompt = """
5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?
"""
messages = [{"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=258)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

在这里插入图片描述

模型回答:根据你的叙述,胃炎胆汁反流性胃炎的可能性大,建议口服奥美拉唑,吗丁啉救治,清淡易消化饮食,忌辛辣打击食物,留意歇息,不要加班除了正规救治胃痛外,患者还需要有看重护理方面,比如恰当饮食,始终保持心情愉快。与此同时患者还要留意决定一家专业医院诊病,这样才能获得良好的治疗效果。

六、模型合并

上面测试还是分开加载的基础模型和lora模型,可以将两个合并为一个,方便后续部署:

# -*- coding: utf-8 -*-
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModeldevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "model/Qwen2-1.5B-Instruct"
lora_dir = "output"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model = PeftModel.from_pretrained(model, lora_dir).to(device)
print(model)
# 合并model, 同时保存 token
model = model.merge_and_unload()
model.save_pretrained("lora_output")
tokenizer.save_pretrained("lora_output")

合并后的结构:

在这里插入图片描述

后面就不需要再通过 PeftModel 直接加载模型既可使用:

# -*- coding: utf-8 -*-
from transformers import AutoModelForCausalLM, AutoTokenizer
import torchmodel_path = "lora_output"device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.to(device)prompt = """
5月至今上腹靠右隐痛,右背隐痛带酸,便秘,喜睡,时有腹痛,头痛,腰酸症状?
"""
messages = [{"role": "system", "content": "你是一个医疗方面的专家,可以根据患者的问题进行解答。"},{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=258)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146588.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

DELPHI编译软件时带上当前IDE的版本号

如果通过 CompilerVersion 得到的也只是编译器的版本号。 比如&#xff1a;delphi XE12 是 36 &#xff0c;也仅此而己。 我想得到的是IDE的版本号&#xff0c;比如当前最新版本的DELPHI是&#xff1a;Embarcadero RAD Studio 12 Version 29.0.53571.9782 我想得到 29.0.53…

【JAVA开源】基于Vue和SpringBoot的网上超市系统

本文项目编号 T 037 &#xff0c;文末自助获取源码 \color{red}{T037&#xff0c;文末自助获取源码} T037&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

16【Protues51单片机仿真】智能洗衣机倒计时系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 用直流电机转动模拟洗衣机。要求 有弱洗、普通洗、强洗三种模式&#xff0c;可通过按键选择。可以设置洗衣时长&#xff0c;通关按键选择15、30、45、60、90分钟。时间到蜂鸣器报警提示。LCD 显示…

Webui 显卡有显存,会报错:CUDA out of memory

Webui 显卡明明有显存&#xff0c;会报错&#xff1a;CUDA out of memory 网上找了很多资料&#xff0c;都没有能解决这个问题 &#xff0c;后来发现和电脑虚拟内存设置有关&#xff0c;这里记录一下具体的解决方法&#xff1a; 什么是 CUDA Out of Memory 错误&#xff1f; …

SAP B1 Web Client MS Teams App集成连载三

过程/Procedure&#xff1a; 1.在应用商店中&#xff0c;点击启动 SAP Business One 应用。应用详细信息页面显示如下。 In the Apps store, click SAP Business One app to launch it. The app details page is displayed as below 2.在左上角&#xff0c;有一个包含两个选项的…

淘宝扭蛋机小程序,扭蛋机文化下的新体验

在数字化时代中&#xff0c;扭蛋机逐渐从传统的线下机器转移到了线上互联网中&#xff0c;市场得到了创新发展。扭蛋机小程序具有便捷、多样化、个性化的特点&#xff0c;迎合了当下消费者的线上消费习惯&#xff0c;又能够让扭蛋机玩家体验到新鲜有趣的扭蛋。 扭蛋机是一种热…

光伏板缺陷红外检测数据集

光伏板缺陷红外检测数据集 包含以下4个数据文件&#xff1a; /train&#xff1a;训练集 /valid&#xff1a;验证集 /test&#xff1a;测试集 README.txt&#xff1a;数据说明 【数据说明】检测目标以Pascal VOC格式进行标注&#xff0c;对每个图像进行以下预处理&#xff0c;统…

PCIE集成验证(五)MSI/MSI-X中断

PCI 总线最早采用的中断机制是 INTx&#xff0c;这是基于边带信号的。后续的 PCI/PCI-X版本&#xff0c;为了消除边带信号&#xff0c;降低系统的硬件设计复杂度&#xff0c;逐渐采用了 MSI(Message Signaled Interrupt)/MSI-X&#xff08;消息信号中断&#xff09;的中断机制。…

救生圈检测系统源码分享

救生圈检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Visio…

使用Renesas R7FA8D1BH (Cortex®-M85)和微信小程序App数据传输

目录 概述 1 系统架构 1.1 系统结构 1.2 系统硬件框架结构 1.3 蓝牙模块介绍 2 微信小程序实现 2.1 UI介绍 2.2 代码实现 3 上位机功能实现 3.1 通信协议 3.2 系统测试 4 下位机功能实现 4.1 功能介绍 4.2 代码实现 4.3 源代码文件 5 测试 5.1 编译和下载代码…

微服务基础设施选型

微服务基础设施架构 微服务基础设施架构全貌 微服务 vs SOA (Round 2) 微服务数量越多越复杂 微服务 vs SOA (Round 3) 微服务把服务的粒度变小&#xff0c;进行了标准化拆分。同时也将ESB拆分为了微服务。 微服务基础设施优先级 这里面体现了基础设施的优先级&#xff0c;如…

人工智能之就业方向(The Employment Direction of Artificial Intelligence)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

银河麒麟桌面操作系统V10(SP1)离线升级SSH(OpenSSH)服务

目录 前言 准备工作 准备与目标服务器相同版本的操作系统 准备编译依赖包 下载OpenSSL源码包 下载OpenSSH源码包 升级OpenSSH服务 查看当前版本信息 安装编译依赖包 安装OpenSSL 安装OpenSSH 前言 OpenSSH是一个广泛使用的开源SSH(安全壳)协议的实现,它提供了安…

手机自动化测试环境之夜神模拟器inspector部署验证

1、自动化测试环境部署_总览图检查表流程图 Python需要安装Appium-Python-Clicent去定位元素&#xff1b;Appium是一个中间的服务器&#xff0c;它需要依赖node.js&#xff0c;python的脚本通过appium和手机进行交互&#xff1b;手机app的环境都是java环境&#xff0c;先安装jd…

PMBOK® 第六版 排列活动顺序

目录 读后感—PMBOK第六版 目录 职场中有句玩笑话&#xff1a;“工作是永远做不完的&#xff0c;任何时候都不可能做完。”这里所吐槽的要点就在于工作任务繁多以及工作缺乏秩序。工作确实是做不完的&#xff0c;倘若工作都能完成&#xff0c;那也就不需要工作了。 工作中令人…

【服务器第二期】mobaxterm软件下载及连接

【服务器第二期】mobaxterm软件下载及连接 前言什么是SSH什么是FTP/SFTP mobaxterm软件介绍mobaxterm软件下载SSH登录使用方法1-新建ssh连接方法2-打开已有的ssh连接方法3-通过ssh命令建立连接 SFTP数据传输方法1-建立ssh连接后直接拖拽方法2-建立sftp连接再拖拽方法3-直接使用…

SURILL MILL搭配cnc机器的打样(3维导入 使用)

导入STP文件&#xff0c;然后 选择 &#xff0c;点击 曲面里的 曲面 炸开 (和曲线分开 ) 到处曲面 的面与 面的先分开了 看 实际情况 &#xff0c;接下来 也可以 曲线炸开 来 分解 组合 然后 &#xff0c;此时选择面还是没有生产成线 点击文件 那一行的曲面 绘制 ,借助曲面…

华为云centos7.9按装ambari 2.7.5 hostname 踩坑记录

华为云centos7.9按装ambari 2.7.5踩坑记录 前言升华总结 前言 一般都是废话&#xff0c;本人专业写bug业余运维。起初找了三台不废弃的台式机&#xff0c;开始重装centos系统&#xff0c;开始了HDP3.1.5Ambari2.7.5安装。 推荐一波好文&#xff0c;一路长绿。跑了一段时间没啥…

3DMAX乐高积木插件LegoBlocks使用方法

3DMAX乐高积木插件LegoBlocks&#xff0c;用户可以通过控件调整和自定义每个乐高积木的外观和大小。 【适用版本】 3dMax2009或更高版本&#xff08;不仅限于此范围&#xff09; 【安装方法】 3DMAX乐高积木插件无需安装&#xff0c;使用时直接拖动插件脚本文件到3dMax视口中…

适用于 Windows 的 7 大数据恢复工具,可靠的数据恢复工具可有效地恢复丢失的文件

数据丢失可能是一种令人沮丧的经历&#xff0c;无论是由于意外删除、磁盘格式化还是系统崩溃。幸运的是&#xff0c;Windows 用户可以使用几种可靠的数据恢复工具来有效地恢复丢失的文件。以下是前七名数据恢复工具的综述&#xff0c;包括奇客数据恢复产品&#xff1a; 适用于 …