【模型学习之路】手写+分析bert

手写+分析bert

目录

前言

架构

embeddings

Bertmodel

预训练任务

MLM

NSP

Bert

后话

netron可视化

code2flow可视化

fine tuning


前言

Attention is all you need!

读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。

毕竟Bert是transformer的变种之一。

架构

embeddings

Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。

这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。

Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):

token_embeddings  和Transformer完全一样。

segment_embeddings  用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。

pos_embeddings  用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。

def get_token_and_segments(tokens_a, tokens_b=None):"""bert的输入之一:token embeddingsbert的输入之二:segment embeddingspos_embeddings在后面的模型里面"""tokens = ['<cls>'] + tokens_a + ['<sep>']segments = [0] * (len(tokens_a) + 2)if tokens_b is not None:tokens += tokens_b + ['<sep>']segments += [1] * (len(tokens_b) + 1)return tokens, segments

Bertmodel

Bert的单个EncoderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。

组装好。

class BertModel(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768):super(BertModel, self).__init__()self.token_embeddings = nn.Embedding(vocab, n)  # [B, m]->[B, m, vocab]->[B, m, n]self.segment_embeddings = nn.Embedding(2, n)  # [B, m]->[B, m, 2]->[B, m, n]self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n))  # [1, max_len, n]self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)for _ in range(n_layers)])def forward(self, tokens, segments, m):  # m是句子长度X = self.token_embeddings(tokens) + \self.segment_embeddings(segments)X += self.pos_embeddings[:, :X.shape[1], :]for layer in self.layers:X, attn = layer(X)return X

简单测试一下。

# 弄一点数据测试一下tokens = torch.randint(0, 100, (2, 10))  # [B, m]segments = torch.randint(0, 2, (2, 10))  # [B, m]m = 10bert = BertModel(100, 768, 3072, 12, 12)out = bert(tokens, segments, m)print(out.shape)  # [2, 10, 768]

预训练任务

Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。

MLM

Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。

在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)

class MLM(nn.Module):def __init__(self, vocab, n, mlm_hid):super(MLM, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, mlm_hid),nn.ReLU(),nn.LayerNorm(mlm_hid),nn.Linear(mlm_hid, vocab))def forward(self, X, P):# X: [B, m, n]# P: [B, p]# 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了p = P.shape[1]P = P.reshape(-1)batch_size = X.shape[0]batch_idx = torch.arange(batch_size)batch_idx = torch.repeat_interleave(batch_idx, p)X = X[batch_idx, P].reshape(batch_size, p, -1)  # [B, p, n]out = self.mlp(X)return out

这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)。其实意会也行:就是训练时,在X[B, m, n]里每个batch有p个是<mask>,我们把它们挑出来,得到[B, p, n]。

NSP

Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合is_next关系,另外50%的第二句话是随机从预料中提取的,不符合is_next关系。分别记为1 | 0。

这个关系由每个句子的第一个token——<cls>捕捉。

class NSP(nn.Module):def __init__(self, n, nsp_hid):super(NSP, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, nsp_hid),nn.Tanh(),nn.Linear(nsp_hid, 2))def forward(self, X):# X: [B, m, n]X = X[:, 0, :]  # [B, n]out = self.mlp(X)  # [B, 2]return out

Bert

下面拼装Bert。

class Bert(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):super(Bert, self).__init__()self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)self.mlm = MLM(vocab, n, mlm_feat)self.nsp = NSP(n, nsp_feat)def forward(self, tokens, segments, m, P=None):X = self.encoder(tokens, segments, m)mlm_out = self.mlm(X, P) if P is not None else Nonensp_out = self.nsp(X)return X, mlm_out, nsp_out

后话

netron可视化

利用netron可视化。

test_tokens = torch.randint(0, 100, (2, 10))  # [B, m]
test_segments = torch.randint(0, 2, (2, 10))  # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)

截取部分看一下。

code2flow可视化

code2flow可以可视化代码函数和类的相互调用关系。

code2flow.code2flow([r'代码路径.py'], '输出路径.svg')

这里生成的png,其实svg清晰得多。

fine tuning

Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!

持续探索Bert......

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/462.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C#实现隐藏和显示任务栏

实现步骤 为了能够控制Windows任务栏&#xff0c;我们需要利用Windows API提供的功能。具体来说&#xff0c;我们会使用到user32.dll中的两个函数&#xff1a;FindWindow和ShowWindow。这两个函数可以帮助我们找到任务栏窗口&#xff0c;并对其执行显示或隐藏的操作 引入命名空…

Excel菜单选项无法点击?两种原因及解决方法全解析

在使用Excel处理数据时&#xff0c;有时会遇到菜单选项无法点击的情况。这种问题会影响到正常的操作和编辑。出现这种情况的原因可能有多种&#xff0c;本文将介绍两种常见的原因&#xff0c;并提供相应的解决方法&#xff0c;帮助小伙伴们快速恢复菜单选项的正常使用。 原因一…

SpringBoot节奏:Web音乐网站构建手册

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

波尼音乐 2.3.0-b1 | 开源免费的音乐播放器,附两个公共接口

波尼音乐最初作为一个毕设项目&#xff0c;凭借其实现了本地与网络音乐播放的能力而受到许多用户的喜爱。随着百度在线音乐API的关闭&#xff0c;波尼音乐逐渐失去在线音乐播放功能。在开源社区的支持下&#xff0c;开发者发现新的网易云音乐API&#xff0c;重启项目并进行全面…

ComfyUI - ComfyUI 工作流中集成 SAM2 + GroundingDINO 处理图像与视频 教程

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/143359538 免责声明&#xff1a;本文来源于个人知识与公开资料&#xff0c;仅用于学术交流&#xff0c;欢迎讨论&#xff0c;不支持转载。 SAM2 与…

C++原创游戏宝强越狱第二季即将回归

抱歉&#xff0c;3个月以来我不是没时间更&#xff0c;而是懒得更。。。 这不宝强越狱第一季完结了么&#xff0c;所以我决定两个月内弄完宝强越狱第二季&#xff0c;第一个版本将在2025年1月1日发布。不过&#xff0c;我还做了个预告片BQYY预告片-CSDN直播&#xff08;33s的垃…

mysql查表相关练习

作业要求&#xff1a; 单表练习&#xff1a; 1 . 查询出部门编号为 D2019060011 的所有员工 2 . 所有财务总监的姓名、编号和部门编号。 3 . 找出奖金高于工资的员工。 4 . 找出奖金高于工资 40% 的员工。 5 找出部门编号为 D2019090011 中所有财务总监&#xff0c;和…

【笔试题】迈入offer的新大门

1. 笔试题1 1.1 题目链接&#xff1a;[NOIP2010]数字统计_牛客题霸_牛客网 1.2 题目描述 补充&#xff1a; 1.3 解法 1.3.1 算法思路 定义变量&#xff0c;L,R,count用于记数。 对规定符合区域范围内的数据进行遍历&#xff0c;对每个数据的每一位进行判断是否为2&#xf…

Gitee push 文件

1、背景 想将自己的plecs仿真放到git中管理&#xff0c;以防丢失&#xff0c;以防乱改之后丢失之前版本仿真。此操作说明默认用户已下载git。 2、操作步骤 2.1 开启Git Bash 在文件夹中右键&#xff0c;开启Git Bash。 2.2 克隆文件 在Git Bash中打git clone git地址&#…

【AIGC】2024-arXiv-Lumiere:视频生成的时空扩散模型

2024-arXiv-Lumiere: A Space-Time Diffusion Model for Video Generation Lumiere&#xff1a;视频生成的时空扩散模型摘要1. 引言2. 相关工作3. Lumiere3.1 时空 U-Net (STUnet)3.2 空间超分辨率的多重扩散 4. 应用4.1 风格化生成4.2 条件生成 5. 评估和比较5.1 定性评估5.2 …

MySQL高可用MHA

目录 一、MHA概述 1.MHA是什么 2.MHA的组成 3.MHA特点 4.MHA工作原理 二、MySQL部署MHA 1.配置主从复制 2.配置MHA高可用 2.1所有服务器安装MHA依赖环境 2.2所有服务器上安装node组件 2.3在MHA manager节点上安装manager组件 2.4在所有服务器上配置无密码认证 …

聚类算法综述

摘要 聚类算法旨在根据数据中的固有模式和相似性将数据组织成组或簇。它们在当今生活中扮演着重要角色&#xff0c;例如在市场营销和电子商务、医疗保健、数据组织和分析以及社交媒体中。现有众多聚类算法&#xff0c;并且不断有新的算法被引入。每个算法都有其自身的优点和缺…

【网络监控加速设备】国产化一站式高性能数据处理平台(海光CPU+复旦微FPGA)

随着网络流量的飞速增长&#xff0c;数据的监控与管理需求日益加剧。针对这一痛点&#xff0c;一款集协议检测、数据监测、报文转发和结果展示于一体的网络监控加速设备&#xff0c;设备百分之百国产化也体现了完全自主可控。设备不仅具备丰富的网络监控功能&#xff0c;还支持…

确保企业架构与业务的一致性与合规性:数字化转型中的关键要素与战略实施

在现代企业的数字化转型过程中&#xff0c;确保企业架构&#xff08;Enterprise Architecture, EA&#xff09;与企业业务的紧密一致性与合规性至关重要。无论是在战略层面还是运营层面&#xff0c;EA都为企业的未来发展提供了清晰的蓝图&#xff0c;确保企业在应对复杂的业务环…

指数分布的原理和应用

本文介绍指数分布&#xff0c;及其推导原理。 Ref: 指数分布 开始之前&#xff0c;先看个概率密度函数的小问题&#xff1a; 问题描述&#xff1a;你于上午10点到达车站&#xff0c;车在10点到10:30 之间到达的时刻 X 的概率密度函数如图&#xff1a; 则使用分段积分&#xff0…

HTML 基础标签——链接标签 <a> 和 <iframe>

文章目录 1. `<a>` 标签属性详细说明示例2. `<iframe>` 标签属性详细说明示例注意事项总结链接标签在HTML中是实现网页导航的重要工具,允许用户从一个页面跳转到另一个页面或嵌入外部内容。主要的链接标签包括 <a> 标签和<iframe> 标签。本文将深入探…

Windows部署rabbitmq

本次安装环境&#xff1a; 系统&#xff1a;Windows 11 软件建议版本&#xff1a; erlang OPT 26.0.2rabbitmq 3.12.4 一、下载 1.1 下载erlang 官网下载地址&#xff1a; 1.2 下载rabbitmq 官网下载地址&#xff1a; 建议使用解压版&#xff0c;安装版可能会在安装软件…

自适应对话式团队构建,提升语言模型代理的复杂任务解决能力

人工智能咨询培训老师叶梓 转载标明出处 如何有效利用多个大模型&#xff08;LLM&#xff09;代理解决复杂任务一直是一个研究热点。由美国南加州大学、宾夕法尼亚州立大学、华盛顿大学、早稻田大学和谷歌DeepMind的研究人员联合提出了一种新的解决方案——自适应团队构建&…

“游戏人”也能拿诺贝尔奖!

“游戏人”也能拿诺贝尔奖&#xff01; 点击蓝链领取游戏开发教程 2024年度的诺贝尔奖&#xff0c;堪称AI深刻影响传统科学领域的一次生动展现。在这一年里&#xff0c;备受瞩目的诺贝尔物理学奖与诺贝尔化学奖两大核心奖项&#xff0c;均颁给了在AI领域做出杰出贡献的研究者…

获取环境变量 getenv小心有坑!

一、背景 在工作中&#xff0c;所做的项目需要涉及两个不同语言( P/Invoke)的信息传递。最后选定了一种环境变量的传递方式&#xff0c;但这也遇到了getenv带来的大坑&#xff01; 1.1 问题现象 我们在C#的exe主流程中通过DllImport&#xff0c;对环境变量进行了设置。随后我…