昇思25天学习打卡营第11天|MindSpore 助力下的 GPT2:数据集加载处理及模型全攻略

目录

环境配置

数据集下载和获取

数据集拆分

处理数据集

模型构建

​​​​​​​模型训练

​​​​​​​模型推理


环境配置


        “%%capture captured_output”这一行指令通常旨在捕获后续整个代码块所产生的输出结果。首先,将已预装的 mindspore 库予以卸载。随后,借助指定的国内镜像源(如 https://pypi.mirrors.ustc.edu.cn/simple )来安装特定版本(即 2.2.14 版)的 mindspore 库。接着,通过另一个国内镜像源(如 https://pypi.tuna.tsinghua.edu.cn/simple )完成指定版本(0.15.0 版)的 tokenizers 库的安装。最后,对 mindnlp 库进行安装操作。

        代码如下:

%%capture captured_output  
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号  
!pip uninstall mindspore -y  
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14  
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple  
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`  
!pip install mindnlp  

数据集下载和获取


        对一个数据集进行下载操作,而后将其加载构建为 TextFileDataset 类型的对象,最后获取该数据集的规模大小信息。

        代码如下:

from mindnlp.utils import http_get  
# download dataset  
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'  
path = http_get(url, './')  
from mindspore.dataset import TextFileDataset  
# load dataset  
dataset = TextFileDataset(str(path), shuffle=False)  
dataset.get_dataset_size()  

        分析:首先,从 mindnlp.utils 模块导入了 http_get 函数。接着,定义了一个字符串,此字符串乃是数据集的下载链接 URL 。随后,运用 http_get 函数从指定的该 URL 下载数据集,并将其保存至当前目录(即'./'),返回的路径被存储在 path 变量里。之后,从 mindspore.dataset 模块引入 TextFileDataset 类。再接着,利用下载完成的数据集的路径创建了一个 TextFileDataset 对象,并将其命名为 dataset ,同时设置不打乱数据的顺序(shuffle=False)。最终,调用 get_dataset_size 方法以获取数据集的大小。

        运行结果:

数据集拆分


        将名为 dataset 的数据集按照比例 0.9 和 0.1 拆分为训练数据集 train_dataset 和测试数据集 test_dataset ,并且拆分过程不进行随机操作(randomize=False)。

        代码如下:

# split into training and testing dataset  
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  

​​​​​​​处理数据集


        第一步:构建了一个用于对数据集予以预处理的函数 process_dataset ,同时借助 BertTokenizer 开展中文文本的处理工作。

        代码如下:

import json  
import numpy as np  
# preprocess dataset  
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):  def read_map(text):  data = json.loads(text.tobytes())  return np.array(data['article']), np.array(data['summarization'])  def merge_and_pad(article, summary):  # tokenization  # pad to max_seq_length, only truncate the article  tokenized = tokenizer(text=article, text_pair=summary,  padding='max_length', truncation='only_first', max_length=max_seq_len)  return tokenized['input_ids'], tokenized['input_ids']      dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # change column names to input_ids and labels for the following training  dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  dataset = dataset.batch(batch_size)  if shuffle:  dataset = dataset.shuffle(batch_size)  return dataset  
from mindnlp.transformers import BertTokenizer  
# We use BertTokenizer for tokenizing chinese context.  
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  
len(tokenizer)

        分析:首先,在 process_dataset 函数中:

        定义了一个内部函数read_map,用于将输入的文本转换为numpy 数组形式的 article 和 summarization 。

        定义了 merge_and_pad 函数,对 article 和 summary 进行分词处理,并进行填充以达到最大序列长度。

        对输入的数据集进行一系列的操作,包括读取映射、合并和填充、分批次处理以及可选的随机打乱。

        接着,通过 BertTokenizer.from_pretrained('bert-base-chinese') 加载了一个预训练的用于中文的 BertTokenizer 。

        最后,len(tokenizer) 尝试获取 tokenizer 对象的长度,但对于 BertTokenizer 来说,len 操作的含义通常不太明确,可能不会得到有意义的结果,或者可能会引发错误,具体取决于 BertTokenizer 类的实现。

        运行结果:

        第二步:对训练数据集 train_dataset 进行处理,并创建一个迭代器。

        代码如下:

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  
next(train_dataset.create_tuple_iterator()) 

        分析:process_dataset 此函数应当是针对数据集展开某种预处理操作的,比如运用给定的 tokenizer 进行分词之类的,同时设定了批处理规模为 4 。而后获取由 create_tuple_iterator 方法所创建的迭代器的下一个元素。

        运行结果:

[Tensor(shape=[4, 1024], dtype=Int64, value=  [[ 101, 1724, 3862 ...    0,    0,    0],  [ 101,  704, 3173 ...    0,    0,    0],  [ 101, 1079, 2159 ... 1745, 8021,  102],  [ 101, 1355, 2357 ...    0,    0,    0]]),  Tensor(shape=[4, 1024], dtype=Int64, value=  [[ 101, 1724, 3862 ...    0,    0,    0],  [ 101,  704, 3173 ...    0,    0,    0],  [ 101, 1079, 2159 ... 1745, 8021,  102],  [ 101, 1355, 2357 ...    0,    0,    0]])]  

​​​​​​​模型构建


        第一步:自定义模型的计算逻辑,用于计算特定的损失值。

        代码如下:

from mindspore import ops  
from mindnlp.transformers import GPT2LMHeadModel  
class GPT2ForSummarization(GPT2LMHeadModel):  def construct(  self,  input_ids = None,  attention_mask = None,  labels = None,  ):  outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)  shift_logits = outputs.logits[..., :-1, :]  shift_labels = labels[..., 1:]  # Flatten the tokens  loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)  return loss  

        分析:定义了一个名为 GPT2ForSummarization 的类,它继承自 GPT2LMHeadModel 。

        在 construct 方法中,首先调用父类的 construct 方法获取输出。然后,对输出的 logits 进行处理,得到 shift_logits(去除了最后一个位置的 logits),对 labels 也进行相应处理得到 shift_labels(去除了第一个位置的 labels)。

        接着,使用 mindspore 中的 ops.cross_entropy 函数计算交叉熵损失。将 shift_logits 和 shift_labels 展平后作为参数传入,同时指定了 ignore_index 为 tokenizer.pad_token_id,这通常表示在计算损失时忽略填充的标记。

        第二步:实现一种先上升后下降的学习率调整策略,前期通过热身逐渐上升学习率,后期随着训练步数的增加逐渐降低学习率。

        代码如下:

from mindspore import ops  
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  
class LinearWithWarmUp(LearningRateSchedule):  """ Warmup-decay learning rate. """  def __init__(self, learning_rate, num_warmup_steps, num_training_steps):  super().__init__()  self.learning_rate = learning_rate  self.num_warmup_steps = num_warmup_steps  self.num_training_steps = num_training_steps  def construct(self, global_step):  if global_step < self.num_warmup_steps:  return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate  return ops.maximum(  0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))  ) * self.learning_rate  

        分析:定义了一个名为 LinearWithWarmUp 的类,它继承自

        mindspore.nn.learning_rate_schedule.LearningRateSchedule 。

        __init__ 方法用于初始化类的属性,包括学习率 learning_rate 、热身步数 num_warmup_steps 和总训练步数 num_training_steps 。

        construct 方法根据传入的当前全局步数 global_step 计算学习率。

        如果 global_step 小于热身步数 num_warmup_steps ,则学习率的计算方式为 global_step 除以最大为 1 的 num_warmup_steps ,再乘以学习率 learning_rate ,实现热身阶段学习率的逐渐上升。

        如果 global_step 大于等于热身步数,学习率的计算方式为 (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps)) 乘以学习率 learning_rate ,并且使用 ops.maximum 函数确保学习率不为负,实现训练后期学习率的逐渐下降。

​​​​​​​模型训练


        第一步:为模型的训练进行准备工作,包括配置模型、设置学习率调度器和优化器,并获取模型的参数数量信息。

        代码如下:

num_epochs = 1  
warmup_steps = 2000  
learning_rate = 1.5e-4  
num_training_steps = num_epochs * train_dataset.get_dataset_size()  
from mindspore import nn  
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  
config = GPT2Config(vocab_size=len(tokenizer))  
model = GPT2ForSummarization(config)  
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)  
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)  
# 记录模型参数数量  
print('number of model parameters: {}'.format(model.num_parameters())) 

        分析:定义了一些训练相关的参数,如训练轮数 num_epochs 为 1,热身步数 warmup_steps 为 2000,学习率 learning_rate 为 1.5e-4 。然后根据训练轮数和训练数据集的大小计算出总的训练步数 num_training_steps 。

        配置了 GPT2 模型的参数 config ,其中指定了词汇表大小。

        创建了 GPT2ForSummarization 模型 model 。

        创建了一个名为 lr_scheduler 的学习率调度器 LinearWithWarmUp ,使用之前定义的学习率、热身步数和总训练步数进行初始化。

        使用 nn.AdamWeightDecay 优化器,并将模型的可训练参数和学习率调度器传递给它进行优化。

        最后打印出模型的参数数量。

        运行结果:

        number of model parameters: 102068736

        第二步:设定检查点的保存路径、名称、保存频次以及最大保存数量。将模型、训练数据集、训练轮数、优化器和回调函数传入。开启混合精度的训练模式。启动训练进程并指定目标列。

        代码如下:

from mindnlp._legacy.engine import Trainer  
from mindnlp._legacy.engine.callbacks import CheckpointCallback  
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization', epochs=1, keep_checkpoint_max=2)  
trainer = Trainer(network=model, train_dataset=train_dataset,  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)  
trainer.set_amp(level='O1')  # 开启混合精度  
trainer.run(tgt_columns="labels")  

        分析:首先,从 mindnlp._legacy.engine 模块导入 Trainer 类和 CheckpointCallback 回调函数。

        然后,创建了一个 CheckpointCallback 实例 ckpoint_cb ,设置了保存检查点的路径为 'checkpoint' ,检查点的名称为 'gpt2_summarization' ,每 1 个 epoch 保存一次检查点,最多保存 2 个检查点。

        接下来,创建了 Trainer 实例 trainer ,将模型 model 、训练数据集 train_dataset 、训练轮数设置为 1 、优化器 optimizer 以及之前创建的回调函数 ckpoint_cb 传递给它。

        之后,通过 trainer.set_amp(level='O1') 开启了混合精度训练模式。

        最后,使用 trainer.run(tgt_columns="labels") 来启动训练过程,并指定目标列是 "labels" ,即训练过程中关注的目标列是 "labels" 。

​​​​​​​模型推理


        第一步:处理测试数据集。

        代码如下:

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):  def read_map(text):  data = json.loads(text.tobytes())  return np.array(data['article']), np.array(data['summarization'])  def pad(article):  tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)  return tokenized['input_ids']  dataset = dataset.map(read_map, 'text', ['article', 'summary'])  dataset = dataset.map(pad, 'article', ['input_ids'])     dataset = dataset.batch(batch_size)  return dataset  
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)  
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))  

        分析:定义了一个名为 process_test_dataset 的函数,用于处理测试数据集。

        函数内部:read_map 函数用于将输入的文本数据解析为文章和摘要的数组。pad 函数使用 tokenizer 对文章进行处理,并截断使其长度不超过指定的最大序列长度减去最大摘要长度,然后返回处理后的输入 ID 序列。

        在函数主体中,首先使用 map 方法应用 read_map 函数将数据解析为文章和摘要,然后应用 pad 函数处理文章,再将数据集按指定的批大小进行分批。

        在主程序中,调用 process_test_dataset 函数处理测试数据集 test_dataset ,并设置批大小为 1 。最后,使用 next 函数获取处理后的数据集的下一个元素,并打印出来。

        第二步:加载预训练模型并设置模型为评估模式,遍历测试数据集的迭代器。

        代码如下:

model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  
model.set_train(False)  
model.config.eos_token_id = model.config.sep_token_id  
i = 0  
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():  output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)  output_text = tokenizer.decode(output_ids[0].tolist())  print(output_text)  i += 1  if i == 1:  break  

        分析:首先,从指定的检查点文件('./checkpoint/gpt2_summarization_epoch_0.ckpt')加载预训练的 GPT2LMHeadModel 模型,并使用给定的配置 config 。然后设置模型为评估模式(set_train(False)),并将模型配置中的结束标记 ID 设置为分隔标记 ID 。

        接下来,通过遍历测试数据集的迭代器,对于每个输入 ID 和原始摘要对,使用模型进行生成。生成时设置最大新生成的标记数为 50,束搜索的束数量为 5,不重复的 n 元语法大小为 2。然后对生成的输出 ID 进行解码得到输出文本,并打印输出。

        最后,设置一个计数器 i ,当 i 达到 1 时停止循环。

        运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473766.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

2024年AI技术深入研究

2024年AI技术持续快速发展,应用领域广泛,产业发展迅速,市场趋势积极,学术研究深入。 AI技术进展大模型发展 2024年,智谱AI正在研发对标OpenAI Sora的高质量文生视频模型,预计最快年内发布。智谱AI的进展显示了国内AI大模型领域的快速发展,以及与国际领先技术的竞争态势…

Ubuntu20.04突然没网的一种解决办法

本来要学一下点云地图处理&#xff0c;用octomap库&#xff0c;但是提示少了octomap-server库&#xff0c;然后通过下面命令安装的时候&#xff1a; sudo apt install ros-noetic-octomap-server 提示&#xff1a;错误:7 https://mirrors.ustc.edu.cn/ubuntu focal-security …

jni原理和实现

一、jni原理 主要就是通过数据类型签名和反射来实现java与c/c方法进行交互的 数据类型签名对应表 javac/cbooleanZbyteBcharCshortSintIlongLfloatFdoubleDvoidVobjectL开头&#xff0c;然后以/分割包的完整类型&#xff0c;后面再加; 比如String的签名就是Ljava/long/Strin…

【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)

目录 一、前言 二、什么是C模板&#xff1f; &#x1f4a6;泛型编程的思想 &#x1f4a6;C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ &#x1f525;非类型模板参数的定义 &#x1f525;非类型模板参数的两种类型 &#x1f52…

机器学习之保存与加载

前言 模型的数据需要存储和加载&#xff0c;这节介绍存储和加载的方式方法。 存和加载模型权重 保存模型使用save_checkpoint接口&#xff0c;传入网络和指定的保存路径&#xff0c;要加载模型权重&#xff0c;需要先创建相同模型的实例&#xff0c;然后使用load_checkpoint…

大厂面试官赞不绝口的后端技术亮点【后端项目亮点合集(2)】

本文将持续更新~~ hello hello~ &#xff0c;这里是绝命Coding——老白~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#xff1a;绝命C…

Redis 中 Set 和 Zset 类型

目录 1.Set类型 1.1 Set集合 1.2 普通命令 1.3 集合操作 1.4 内部编码 1.5 使用场景 2.Zset类型 2.1 Zset有序集合 2.2 普通命令 2.3 集合间操作 2.4 内部编码 2.5 使用场景 1.Set类型 1.1 Set集合 集合类型也是保存多个字符串类型的元素&#xff0c;但是和列表类型不同的是&…

QT的编译过程(底层逻辑)

qmake -project 用于从源代码生成项目文件&#xff0c;qmake 用于从项目文件生成 Makefile&#xff0c;而 make 用于根据 Makefile 构建项目。 详细解释&#xff1a; qmake -project 这个命令用于从源代码目录生成一个初始的 Qt 项目文件&#xff08;.pro 文件&#xff09;。它…

2.1 tmux和vim

文章目录 前言概述tmuxvim总结 前言 开始学习的时间是 2024.7.6 ,13&#xff1a;47 概述 最好多使用&#xff0c;练成条件反射式的 直接使用终端的工具&#xff0c;可以连接到服务器&#xff0c;不需要使用本地的软件 tmux 这个主要有两个功能&#xff0c;第一个功能是分…

SpringBoot项目练习

文章目录 SpringBootVue后台管理系统所需软件下载、安装、版本查询Vue搭建一个简单的Vue项目 Spring项目1项目架构 SpringBootVue后台管理系统 学习视频&#xff1a; https://www.bilibili.com/video/BV1U44y1W77D/?spm_id_from333.337.search-card.all.click&vd_sourcec…

2024年最新运维面试题(附答案)

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 公众号&#xff1a;网络豆云计算学堂 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a; 网络豆的主页​​​​​ 一&#xff0e;选择题 1.HTTP协议默认使用哪个端口…

【刷题汇总--大数加法、 链表相加(二)、大数乘法】

C日常刷题积累 今日刷题汇总 - day0061、大数加法1.1、题目1.2、思路1.3、程序实现 2、 链表相加(二)2.1、题目2.2、思路2.3、程序实现 3、大数乘法3.1、题目3.2、思路3.3、程序实现 4、题目链接 今日刷题汇总 - day006 1、大数加法 1.1、题目 1.2、思路 读完题,明白大数相加…

最新版情侣飞行棋dofm,已解锁高阶私密模式,单身狗务必绕道!(附深夜学习资源)

今天阿星要跟大家聊一款让阿星这个大老爷们儿面红耳赤的神奇游戏——情侣飞行棋。它的神奇之处就在于专为情侣设计&#xff0c;能让情侣之间感情迅速升温&#xff0c;但单身狗们请自觉绕道&#xff0c;不然后果自负哦&#xff01; 打开游戏&#xff0c;界面清新&#xff0c;操…

平价猫粮新选择!福派斯鲜肉猫粮,让猫咪享受美味大餐!

福派斯鲜肉猫粮&#xff0c;作为一款备受铲屎官们青睐的猫粮品牌&#xff0c;凭借其卓越的品质和高性价比&#xff0c;为众多猫主带来了健康与美味的双重享受。接下来&#xff0c;我们将从多个维度对这款猫粮进行解析&#xff0c;让各位铲屎官更加全面地了解它的魅力所在。 1️…

11.x86游戏实战-汇编指令add sub inc dec

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 上一个内容&#xff1a;10.x86游戏实战-汇编指令lea 首先双击下图红框位置 然后在下图红框位置输入0 然…

电商视角如何理解动态IP与静态IP

在电子商务的蓬勃发展中&#xff0c;网络基础设施的稳定性和安全性是至关重要的。其中&#xff0c;IP地址作为网络设备间通信的基础&#xff0c;扮演着举足轻重的角色。从电商的视角出发&#xff0c;我们可以将动态IP和静态IP比作电商平台上不同类型的店铺安排&#xff0c;以此…

记录一次MySQL恢复

一、前言 此文章由一次数据库被黑客删除而引发 由于对于Linux操作、docker使用、MySQL原理这些都相对不是很熟悉&#xff0c;所以记录下来避免以后在工作中遇到类似的问题而惊慌失措。 1.MySQL环境现状 docker管理的&#xff0c;8.0.26版本 启动语句: docker run -d -p 33…

智慧矿山建设规划方案(121页Word)

智慧矿山建设项目方案摘要 一、项目背景及现状分析 项目背景 随着信息技术的迅猛发展&#xff0c;智慧化、数字化已成为矿山行业转型升级的必然趋势。智慧矿山建设项目旨在通过集成先进的信息技术手段&#xff0c;实现对矿山生产、管理、安全等全过程的智能化监控与管理&…

【ARMv8/v9 GIC 系列 1.5 -- Enabling the distribution of interrupts】

请阅读【ARM GICv3/v4 实战学习 】 文章目录 Enabling the distribution of interruptsGIC Distributor 中断组分发控制CPU Interface 中断组分发控制Physical LPIs 的启用Summary Enabling the distribution of interrupts 在ARM GICv3和GICv4体系结构中&#xff0c;中断分发…

如何搭建Ubuntu环境安装禅道

一、禅道安装部署的环境要求 禅道安装部署环境推荐使用 Linux Apache PHP7.0以上版本 MySQL5.5以上版本/MariaDB的组合。Nginx其次&#xff0c;不推荐IIS PHP组合。禅道需要使用PHP的这些扩展&#xff1a;pdo、pdo_mysql、json、filte、openssl、mbstring、zlib、curl、gd、…