BERT的代码实现

目录

1.BERT的理论

2.代码实现 

 2.1构建输入数据格式

 2.2定义BERT编码器的类

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

2.3.2 任务二:next sentence prediction

3.整合代码 

 4.知识点个人理解


 

1.BERT的理论

BERT全称叫做Bidirectional Encoder Representations from Transformers, 论文地址: [1810.04805] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (arxiv.org)

BERT是谷歌AI研究院在2018年10月提出的一种预训练模型. BERT本质上就是Transformer模型的encoder部分, 并且对encoder做了一些改进.

  • 官方代码和预训练模型 Github: https://github.com/google-research/bert

下图中编码器部分即BERT的基本结构.

  

2.代码实现 

import torch
from torch import nn
import dltools

 2.1构建输入数据格式

def get_tokens_and_segments(tokens_a, tokens_b=None):#classification 分类#BERT是两句话作为一对句子一同传入的,也可以单独传一句话,若序列长度长,可以补padding#假设先传一句话tokens_atokens = ['<cls>'] + tokens_a + ['<sep>']  #tokens_embedding层的处理segments = [0] * (len(tokens_a) + 2)  #判断词元属于哪一句话,加标记,0属于第一句话if tokens_b is not None:tokens += tokens_b + ['sep']segments += [1] * (len(tokens_b) + 1)return tokens, segments#测试上面的函数
get_tokens_and_segments([1, 2, 3], [4, 5, 6])

(['<cls>', 1, 2, 3, '<sep>', 4, 5, 6, 'sep'], [0, 0, 0, 0, 0, 1, 1, 1, 1])

 2.2定义BERT编码器的类

class BERTEncoder(nn.Module):#由于前馈网络的ffn_num_outputs = num_hiddens,没有初始化传入#__init__()里面的参数,是创建类的时候传入的参数def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs):super().__init__(**kwargs)#token_embeddings层self.token_emdedding = nn.Embedding(vocab_size, num_hiddens)#segment_embedding层  (传入两个句子,所以第0维为2)self.segment_embedding = nn.Embedding(2, num_hiddens)#pos_embedding层  :位置嵌入层是可以学习的, 用nn.Parameter()定义可学习的参数self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))#设置Encoder_block的数量self.blks = nn.Sequential()  #为使用的Encoder_block依次编号for i in range(num_layers):  #有几层网络循环几层self.blks.add_module(f'{i}', dltools.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout))#__init__()里面的参数,是创建类的时候传入的参数#foward里面的参数是创建完类对象之后,调用类方法时传入的参数def forward(self, tokens, segments, valid_lens):#X = token_embedding + segment_embedding + pos_embedding#传入的token_embedding,segment_embedding两者的shape相同,可以直接相加X = self.token_emdedding(tokens) + self.segment_embedding(segments)#pos_embedding与前两层的数据shape不相同,不能直接相加#切片让self.pos_embedding的第1维度的数据切片到token_embedding,segment_embedding相加之后的数X = X + self.pos_embedding.data[:, :X.shape[1], :]for blk in self.blks:X = blk(X, valid_lens)return X  
#测试上面代码#创建BERTEncoder类对象
vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)tokens = torch.randint(0, vocab_size, (2, 8)) #生成随机正整数
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1]])
#调用类方法
encoded_X = encoder(tokens, segments, None)encoded_X.shape
torch.Size([2, 8, 768])

#  nn.Sequential()是PyTorch中的一个类,它允许用户将多个计算层按照顺序组合成一个模型。在深度学习中,模型可以是由各种不同类型的层组成的,例如卷积层、池化层、全连接层等。nn.Sequential()方法可以将这些层组合在一起,形成一个整体模型。 

 2.3BERT的两个任务

2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 

class MaskLM(nn.Module):def __init__(self, vocab_size, num_inputs=768, **kwargs):super().__init__(**kwargs)self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),  #全连接层nn.ReLU(),  nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size))  #输出层#X表示随机(15%概率)将一些词元换成mask#pred_positions表示已经处理好的80%概率将选中的词换成mask>, 10%概率换成随机词元,10%概率保持原有词元#pred_position是二维数据def forward(self, X, pred_positions):  num_pred_positions = pred_positions.shape[1]  #索引出80%、10%、10%三个概率选出的需要转换的词位置数量pred_positions = pred_positions.reshape(-1)  #变成一维数据batch_size = X.shape[0]  #获取批次batch_idx = torch.arange(0, batch_size) #获取批次的编号#将批次编号与元素数量对应起来#例如:batch_size = [0, 1]   -->   [0, 0, 0, 1, 1, 1]batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)  #将batch_idx中每个元素重复num_pred_positions次#把要预测位置的数据取出来masked_X = X[batch_idx, pred_positions]masked_X = masked_X.reshape(batch_size, num_pred_positions, -1)  #还原维度mlm_Y_hat = self.mlp(masked_X)return mlm_Y_hat
#测试代码mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)mlm_Y_hat.shape    #2:2个批次,   3:三个需要转换词元的位置     10000:计算的概率数量(在最后会用softmax函数计算分类结果),vocab_size有10000个,
torch.Size([2, 3, 10000])
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])  #假设真实值
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))  # mlm_Y_hat的shape=(6, 10000)     mlm_Y的shape=(6)
mlm_l.shape

torch.Size([6])

2.3.2 任务二:next sentence prediction

class NextSentencePred(nn.Module):def __init__(self, num_inputs, **kwargs):super().__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)  #预测输入的句子是否为下一个句子,预测目标值为“是/否”二分类问题def forward(self, X):#X的形状(batch_size, num_hiddens)return self.output(X)
#测试代码encoded_X = torch.flatten(encoded_X, start_dim=1)  #将数据展平,相当于reshape
nsp = NextSentencePred(encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)nsp_Y_hat.shape
torch.Size([2, 2])
#计算损失
nsp_y = torch.tensor([0, 1])   #假设真实值
nsp_1 = loss(nsp_Y_hat, nsp_y)
nsp_1.shape

torch.Size([2])

3.整合代码 

class BERTModel(nn.Module):def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768,hid_in_features=768, mlm_in_features=768, nsp_in_features=768, **kwargs):super().__init__(**kwargs)#初始化编码器对象self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout,max_len=max_len, key_size=key_size, query_size=query_size, value_size=value_size)#掩蔽语言模型任务self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)#中间隐藏层的线性转换+激活函数self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())#预测出下一句self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, seqments, valid_lens=None, pred_position=None):encoded_X = self.encoder(tokens, seqments, valid_lens)if pred_position is not None:mlm_Y_hat = self.mlm(encoded_X, pred_position)else:pred_position = None#0表示<cls>标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

 4.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1542604.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 静态库与动态库的制作与使用

在Linux中&#xff0c;库library是一组函数和资源的集合&#xff0c;他们可以被不同的程序共享和使用&#xff0c;库的主要目的是代码重用&#xff0c;减少内存占用&#xff0c;并简化程序的维护。 Linux操作系统支持的函数库分为&#xff1a;静态库和动态库。 静态库&#xf…

【线程池】Tomcat线程池

版本&#xff1a;tomcat-embed-core-10.1.8.jar 前言 最近面试被问到 Tomcat 线程池&#xff0c;因为之前只看过 JDK 线程池&#xff0c;没啥头绪。在微服务横行的今天&#xff0c;确实还是有必要研究研究 Tomcat 的线程池 Tomcat 线程池和 JDK 线程池最大的不同就是它先把最…

二分+优先队列例题总结(icpc vp+牛客小白月赛)

题目 思路分析 要求输出最小的非负整数k&#xff0c;同时我们还要判断是否存在x让整个序列满足上述条件。 当k等于某个值时&#xff0c;我们可以得到x的一个取值区间&#xff0c;若所有元素得到的x的区间都有交集(重合)的话,那么说明存在x满足条件。因为b[i]的取值为1e9&…

Maven-一、分模块开发

Maven进阶 文章目录 Maven进阶前言创建新模块向新模块装入内容使用新模块把模块部署到本地仓库补充总结 前言 分模块开发可以把一个完整项目中的不同功能分为不同模块管理&#xff0c;然后模块间可以相互调用&#xff0c;该篇以一个SSM项目为目标展示如何使用maven分模块管理。…

没错,我给androidx修了一个bug!

不容易啊&#xff0c;必须先截图留恋&#x1f601; 这个bug是发生在xml中给AppcompatTextView设置textFontWeight&#xff0c;但是却无法生效。修复bug的代码也很简单&#xff0c;总共就几行代码&#xff0c;但是在找引起这个bug的原因和后面给androidx提pr却花了很久。 //App…

云手机的海外原生IP有什么用?

在全球数字化进程不断加快的背景下&#xff0c;企业对网络的依赖程度日益加深。云手机作为一项创新的工具&#xff0c;正逐步成为企业优化网络结构和全球业务拓展的必备。尤其是云手机所具备的海外原生IP功能&#xff0c;为企业进入国际市场提供了独特的竞争优势。 什么是海外原…

DNF Decouple and Feedback Network for Seeing in the Dark

DNF: Decouple and Feedback Network for Seeing in the Dark 在深度学习领域&#xff0c;尤其是在低光照图像增强的应用中&#xff0c;RAW数据的独特属性展现出了巨大的潜力。然而&#xff0c;现有架构在单阶段和多阶段方法中都存在性能瓶颈。单阶段方法由于域歧义&#xff0c…

如何使用 3 种简单的方法将手写内容转换为文本

手写比文本更具艺术性&#xff0c;这就是许多人追求手写字体的原因。有时&#xff0c;我们必须将手写内容转换为文本&#xff0c;以便于存储和阅读。本文将指导您如何轻松转换它。 此外&#xff0c;通常以扫描的手写内容编辑文本很困难&#xff0c;但使用奇客免费OCR&#xff…

视觉距离与轴距离的转换方法

1.找一个明显的参照物&#xff0c;用上方固定的相机拍一下。保存好图片 2.轴用定长距离如1mm移动一下。 3.再用上相机再取一张图。 4.最后用halcon 将两图叠加 显示 效果如下 从图上可以明显的看出有两个图&#xff0c;红色标识的地方。 这时可以用halcon的工具画一个长方形…

Cesium 绘制可编辑点

Cesium Point点 实现可编辑的pointEntity 实体 文章目录 Cesium Point点前言一、使用步骤二、使用方法二、具体实现1. 开始绘制2.绘制事件监听三、 完整代码前言 支持 鼠标按下 拖动修改点,释放修改完成。 一、使用步骤 1、点击 按钮 开始 绘制,单击地图 绘制完成 2、编辑…

误差评估,均方误差、均方根误差、标准差、方差

均方根误差 RMSE/RMS 定义 RMSE是观察值与真实值偏差的平方&#xff0c;对于一组观测值 y i y_i yi​ 和对应的真值 t i t_i ti​ R M S E 1 n ∑ i 1 n ( y i − t i ) &#xff0c;其中n是观测次数 RMSE\sqrt{\frac1n \sum_{i1}^n (y_i-t_i)} \text{&#xff0c;其中n是…

2.个人电脑部署MySQL,傻瓜式教程带你拥有个人金融数据库!

2.个人电脑部署MySQL&#xff0c;傻瓜式教程带你拥有个人金融数据库&#xff01; ‍ 前边我们提到&#xff0c;比较适合做量化投研的数据库是MySQL&#xff0c;开源免费。所以今天我就写一篇教程来教大家如何在自己的环境中部署MySQL。 在不同的设备或系统中安装MySQL的步骤…

局部凸空间及其在算子空间中的应用之四——归纳极限空间2

局部凸空间及其在算子空间中的应用之四——归纳极限空间2 前言一、归纳极限拓扑中极限的含义总结 数学的真理是绝对的&#xff0c;它超越了时间和空间。——约翰冯诺伊曼 前言 在上一篇文章中&#xff0c;我们讨论了归纳极限拓扑的概念和与连续线性算子有关的一个重要结论。认…

为什么编程很难?

之前有一个很紧急的项目&#xff0c;项目中有一个bug始终没有被解决&#xff0c;托了十几天之后&#xff0c;就让我过去协助解决这个bug。这个项目是使用C语言生成硬件code&#xff0c;是更底层的verilog&#xff0c;也叫做HLS开发。 项目中的这段代码并不复杂&#xff0c;代码…

postman控制变量和常用方法

1、添加环境&#xff1a; 2、环境添加变量&#xff1a; 3、配置不同的环境&#xff1a;local、dev、sit、uat、pro 4、 接口调用 5、清除cookie方法&#xff1a; 6、下载文件方法&#xff1a;

calibre-web报错:File type isn‘t allowed to be uploaded to this server

calibre-web报错&#xff1a;File type isnt allowed to be uploaded to this server 最新版的calibre-web在Upload时候会报错&#xff1a; File type isnt allowed to be uploaded to this server 解决方案&#xff1a; Admin - Basic Configuration - Security Settings 把…

2024PDF内容修改秘籍:工具推荐与技巧分享

现在我们使用PDF文档的频率越来越高了&#xff0c;很多时候收到的表格之类的资料也都是PDF格式的&#xff0c;如果进行转换之后编辑再转换为PDF格式还是有点麻烦的&#xff0c;那么pdf怎么编辑修改内容呢&#xff1f;这篇文章我将介绍几款可以直接编辑PDF文件的工具来提高我们的…

鸿蒙next 带你玩转鸿蒙拍照和相册获取图片

前言导读 各位网友和同学&#xff0c;相信大家在开发app的过程中都有遇到上传图片到服务器的需求&#xff0c;我们一般是有两种方式&#xff0c;拍照获取照片或者调用相册获取照片&#xff0c;今天我们就分享一个小案例讲一下这两种情况的实现。废话不多说我们正式开始 效果图…

安全带检测系统源码分享

安全带检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Visio…

Linux,uboot,kernel启动流程,S5PV210芯片的启动流程,DRAM控制器初始化流程

一、S5PV210芯片的DRAM控制器介绍、初始化DDR的流程分析 1、DRAM的地址空间 1)从地址映射图可以知道&#xff0c;S5PV210有两个DRAM端口。 DRAM0的内存地址范围&#xff1a;0x20000000&#xff5e;0x3FFFFFFF&#xff08;512MB&#xff09;&#xff1b;DRAM1:的内存地址范围…