BERT训练环节(代码实现)

1.代码实现

#导包
import torch
from torch import nn
import dltools
#加载数据需要用到的声明变量
batch_size, max_len = 1, 64
#获取训练数据迭代器、词汇表
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)
#其余都是二维数组
#tokens, segments, valid_lens(一维), pred_position, mlm_weights, mlm, nsp(一维)对应每条数据i中包含的数据
for i in train_iter:  #遍历迭代器break   #只遍历一条数据
[tensor([[    3,    25,     0,  4993,     0,    24,     4,    26,    13,     2,158,    20,     5,    73,  1399,     2,     9,   813,     9,   987,45,    26,    52,    46,    53,   158,     2,     5,  3140,  5880,9,   543,     6,  6974,     2,     2,   315,     6,     8,     5,8698,     8, 17229,     9,   308,     2,     4,     1,     1,     1,1,     1,     1,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1]]),tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),tensor([47.]),tensor([[ 9, 15, 26, 32, 34, 35, 45,  0,  0,  0]]),tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]]),tensor([[ 484, 1288,   20,    6, 2808,    9,   18,    0,    0,    0]]),tensor([0])]
#创建BERT网络模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128], ffn_num_input=128, ffn_num_hiddens=256, num_heads=2, num_layers=2, dropout=0.2, key_size=128, query_size=128, value_size=128, hid_in_features=128, mlm_in_features=128, nsp_in_features=128)
#调用设备上的GPU
devices = dltools.try_all_gpus()
#损失函数对象
loss = nn.CrossEntropyLoss()   #多分类问题,使用交叉熵
#@save    #表示用于指示某些代码应该被保存或导出,以便于管理和重用
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):#前向传播#获取遮蔽词元的预测结果、下一个句子的预测结果_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)#计算遮蔽语言模型的损失mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1,1)mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)   #MLM损失函数的归一化版本   #加一个很小的数1e-8,防止分母为0,抵消上一行代码乘以的数值#计算下一个句子预测任务的损失nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_lreturn mlm_l, nsp_l, l  
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):  #文本词元样本量太多,全跑完花费的时间太多,若num_steps=1在BERT中表示,跑了1个batch_sizenet = nn.DataParallel(net, device_ids=devices).to(devices[0])  #调用设备的GPUtrainer = torch.optim.Adam(net.parameters(), lr=0.01)   #梯度下降的优化算法Adamstep, timer = 0, dltools.Timer()  #设置计时器#调用画图工具animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])#遮蔽语言模型损失的和, 下一句预测任务损失的和, 句子对的数量, 计数metric = dltools.Accumulator(4)  #Accumulator类被设计用来收集和累加各种指标(metric)num_steps_reached = False  #设置一个判断标志, 训练步数是否达到预设的步数while step < num_steps and not num_steps_reached:for tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y in train_iter:#将遍历的数据发送到设备上tokens_X = tokens_X.to(devices[0])segments_X = segments_X.to(devices[0])valid_lens_x = valid_lens_x.to(devices[0])pred_positions_X = pred_positions_X.to(devices[0])mlm_weights_X = mlm_weights_X.to(devices[0])mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])#梯度清零trainer.zero_grad()timer.start()  #开始计时mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)l.backward()  #反向传播trainer.step()  #梯度更新metric.add(mlm_l, nsp_l, tokens_X.shape[0], l)  #累积的参数指标timer.stop() #计时停止animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3]))  #画图的step += 1  #训练完一个batch_size,就+1if step == num_steps:  #若步数与预设的训练步数相等num_steps_reached = True   #判断标志改为Truebreak  #退出while循环print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')print(f'{metric[2]/ timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')
train_bert(train_iter, net, loss, len(vocab), devices, 500)

 

def get_bert_encoding(net, tokens_a, tokens_b=None):tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)  #unsqueeze(0)增加一个维度segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)  valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)endoced_X, _, _ = net(token_ids, segments, valid_len)return endoced_X
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),torch.Size([1, 128]),tensor([-0.5872, -0.0510, -0.7376], device='cuda:0', grad_fn=<SliceBackward0>))
encoded_text_crane

 

tensor([[-5.8725e-01, -5.0994e-02, -7.3764e-01, -4.3832e-02,  9.2467e-02,1.2745e+00,  2.7062e-01,  6.0271e-01, -5.5055e-02,  7.5122e-02,4.4872e-01,  7.5821e-01, -6.1558e-02, -1.2549e+00,  2.4479e-01,1.3132e+00, -1.0382e+00, -4.7851e-03, -6.3590e-01, -1.3180e+00,5.2245e-02,  5.0982e-01,  7.4168e-02, -2.2352e+00,  7.4425e-02,5.0371e-01,  7.2120e-02, -4.6384e-01, -1.6588e+00,  6.3987e-01,-6.4567e-01,  1.7187e+00, -6.9696e-01,  5.6788e-01,  3.2628e-01,-1.0486e+00, -7.2610e-01,  5.7909e-02, -1.6380e-01, -1.2834e+00,1.6431e+00, -1.5972e+00, -4.5678e-03,  8.8022e-02,  5.5931e-02,-7.2332e-02, -4.9313e-01, -4.2971e+00,  6.9757e-01,  7.0690e-02,-1.8613e+00,  2.0366e-01,  8.9868e-01, -3.4565e-01,  9.6776e-02,1.3699e-02,  7.1410e-01,  5.4820e-01,  9.7358e-01, -8.1038e-01,2.6216e-01, -5.7850e-01, -1.1969e-01, -2.5277e-01, -2.0046e-01,-1.6718e-01,  5.5540e-01, -1.8172e-01, -2.5639e-02, -6.0961e-01,-1.1521e-03, -9.2973e-02,  9.5226e-01, -2.4453e-01,  9.7340e-01,-1.7908e+00, -2.9840e-02,  2.3087e+00,  2.4889e-01, -7.2734e-01,2.1827e+00, -1.1172e+00, -7.0915e-02,  2.5138e+00, -1.0356e+00,-3.7332e-02, -5.6668e-01,  5.2251e-01, -5.0058e-01,  1.7354e+00,4.0760e-01, -1.2982e-01, -7.0230e-01,  3.1563e+00,  1.8754e-01,2.0220e-01,  1.4500e-01,  2.3296e+00,  4.5522e-02,  1.1762e-01,1.0662e+00, -4.0858e+00,  1.6024e-01,  1.7885e+00, -2.7034e-01,-1.6869e-01, -8.7018e-02, -4.2451e-01,  1.1446e-01, -1.5761e+00,7.6947e-02,  2.4336e+00,  4.5346e-02, -6.5078e-02,  1.4203e+00,3.7165e-01, -7.9571e-01, -1.3515e+00,  4.1511e-02,  1.3561e-01,-3.3006e+00,  1.4821e-01,  1.3024e-01,  1.9966e-01, -8.5910e-01,1.4505e+00,  7.6774e-02,  9.3771e-01]], device='cuda:0',grad_fn=<SliceBackward0>)
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just', 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]

 

(torch.Size([1, 10, 128]),torch.Size([1, 128]),tensor([-0.4637, -0.0569, -0.6119], device='cuda:0', grad_fn=<SliceBackward0>))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/150300.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

geodatatool(地图资源下载工具)3.8更新

geodatatool&#xff08;地图资源下载工具&#xff09;3.8&#xff08;新&#xff09;修复更新&#xff0c;修复更新包括&#xff1a; 1.选中下载数据时显示选中个数。 其它一些BUG修复。 如您有其它问题及需求&#xff0c;也可以联系我们&#xff0c;我们将持续维护更新该工…

刷题日记_DAY1

前言 这里记录每日随机刷的错题 两个数组的交集&#xff08;模拟&#xff09; 题目描述 题目解析 题目要求返回指定的两个字符串之间的距离&#xff0c;容易想到的一种解法就是暴力遍历&#xff0c;来个双循环&#xff0c;但时间复杂度就为N^2&#xff0c;不符合题意 for(…

【最新华为OD机试E卷-支持在线评测】绘图机器(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…

ESXI主机加入VCENTER现有集群提示出现常规性错误

背景&#xff1a;由于忘记了这台主机的root密码&#xff0c;所以在迁移完虚拟机后给这台主机重新安装了操作系统&#xff0c;装完操作系统加集群提示如下报错&#xff1a; 查阅了一些资料后发现主机的CPU是一样的&#xff0c;不需要开EVC&#xff1b; 也有一些说需要改这个配置…

【专题】2024年中国白酒行业数字化转型研究报告合集PDF分享(附原数据表)

原文链接&#xff1a;https://tecdat.cn/?p37755 消费人群趋于年轻化&#xff0c;消费需求迈向健康化&#xff0c;消费场景与渠道走向多元化&#xff0c;这些因素共同驱动企业凭借数据能力来适应市场的变化。从消费市场来看&#xff0c;消费群体、需求、场景及渠道皆展现出与…

图文组合商标部分驳回后优化后初审通过!

这几天以前有个企业的商标初审下来了&#xff0c;以前是加了图形个别部分没有通过初审&#xff0c;后面是把图形去掉重新用文字申请下来初审。 图形与文字同时申请&#xff0c;会分别审查有一个元素过不了&#xff0c;整体就会过不了&#xff0c;所以平常就会建议分开申请注册商…

【Linux实践】实验三:LINUX系统的文件操作命令

【Linux实践】实验三&#xff1a;LINUX系统的文件操作命令 实验目的实验内容实验步骤及结果1. 切换和查看目录2. 显示目录下的文件3. 创建和删除目录① mkdir② rm③ rmdir 4. 输出和重定向① 输出② 重定向 > 和 >> 5. 查看文件内容① cat② head 6. 权限7. 复制8. 排…

【微服务即时通讯系统】——etcd一致性键值存储系统,etcd的介绍,etcd的安装,etcd使用和功能测试

文章目录 etcd1. etcd的介绍1.1 etcd的概念 2. etcd的安装2.1 安装etcd2.2 安装etcd客户端C/C开发库 3. etcd使用3.1 etcd接口介绍 4. etcd使用测试4.1 原生接口使用测试4.2 封装etcd使用测试 etcd 1. etcd的介绍 1.1 etcd的概念 Etcd 是一个基于GO实现的 分布式、高可用、一致…

UE学习篇ContentExample解读------Blueprint_Communication-上

文章目录 总览描述批次阅览1.1 Basic communication with a target blueprint1.2 Basic communication via actor casting1.3 Blueprint communication via actor casting to child Blueprint1.4 Communicating with all actors of a specific class 概念总结致谢&#xff1a; …

vite分目录打包以及去掉默认的.gz 文件

1.vite打包情况介绍&#xff1a; 1.1vite在不进行任何配置的情况下&#xff0c;会将除开public的所有引用到资源打包编译添加哈希值至assets文件夹中&#xff08;非引用文件以及行内样式图片未被打包编译资源会被treeSharp直接忽略不打包&#xff09;&#xff0c;     1.2w…

SpringBoot框架之KOB项目 - 配置Mysql与注册登录模块(中)

修改Spring Security 登录验证模式 传统的验证登录模式 公开页面&#xff1a;输入url就可以直接访问授权页面&#xff1a;登录之后才可以访问 Jwt验证模式 容易实现跨域不需要在服务器端存储 对比于传统模式将所有的sessionId换成jwt token access token refresh token 过…

neo4j小白入门

1.建立几个学校的节点 1.1创建一个节点的Cypher命令 create (Variable:Lable {Key1:Value,Key2,Value2}) return Variable 1.2创建一个学校的节点 create (n:School{name:清华大学,code: 10003,establishmentDate:date ("1911-04-29")})return n 1.3一次创建几个…

在Markdown中实现内部查询

markdown实现内部查询 在想要跳转到的位置添加 <a idxxx></a> 标签&#xff0c;如下图&#xff1a; 然后按如下格式添加目录 [跳转文字](#id)&#xff1a; 如上操作即可实现markdown内部查询。 具体实现效果如下&#xff1a;

通过service访问Pod

假设Pod中的容器可能因为各种原因发生故障而死掉&#xff0c;Deployment等controller会通过动态创建和销毁Pod来保证应用整体的健壮性&#xff0c;换句话说&#xff0c;Pod是脆弱的&#xff0c;但应用是健壮的 每个Pod都有自己的Ip&#xff0c;当controller用新的Pod替代发生故…

seL4 Mapping(三)

官网链接: Mapping Mapping 这节课程主要是介绍seL4的虚存管理。 虚存 Virtual memory 除了用于操作硬件分页结构的内核原语之外&#xff0c;seL4不提供虚拟内存管理。用户必须为创建中间级分页结构&#xff0c;映射页面以及取消映射页面提供服务。 用户可以随意的定义他们…

6种常见位运算符+异或运算符的使用(加密、解密)

一、位运算符 位运算符进行的是整数与整数之间的运算 1、右移运算符&#xff1a;>> &#xff08;1&#xff09;相当于对整数除以2 &#xff08;2&#xff09;举例&#xff1a; int num 2; System.out.println(num >> 1); 2、左移运算符&#xff1a;<< …

opencv-python学习笔记10-图像形态学处理

目录 一、基本概念&#xff1a; &#xff08;1&#xff09;结构元素&#xff08;Structuring Element&#xff09;&#xff1a; &#xff08;2&#xff09;膨胀&#xff08;Dilation&#xff09;&#xff1a; &#xff08;3&#xff09;腐蚀&#xff08;Erosion&#xff0…

巧用解压软件:高效处理云盘文件

百度网盘支持多种文件格式&#xff0c;包括文本文件格式如.txt、.doc、.docx 等&#xff1b;图片文件格式如.jpg、.png 等&#xff1b;音频文件格式如.mp3、.wav 等&#xff1b;视频文件格式如.avi、.mp4 等&#xff1b;压缩文件格式如.zip、.rar、.7z 等&#xff1b;可执行文件…

进度条QProgressBar

进度条控价&#xff0c;用来只是任务的完成情况 值 包括当前值、最大值、最小值 // 获取和设置当前值 int value() const; void setValue(int);// 获取和设置最大值 int maximum() const; void setMaximum(int);// 获取和设置最小值 int minimum() const; void setMinimum(i…

http增删改查四种请求方式操纵数据库

注意&#xff1a;在manage.py项目入口文件中的路由配置里&#xff0c;返回响应的 return语句后面的代码不会执行&#xff0c;所以路由配置中每个模块代码要想都执行&#xff0c;不能出现return 激活虚拟环境&#xff1a;venv(我的虚拟环境名称&#xff09;\Scripts\activate …