基于Python的自然语言处理系列(21):BERT模型实现

        在本篇文章中,我们将介绍并实现BERT(Bidirectional Encoder Representations from Transformers)。与传统的Transformer模型相比,BERT的主要区别在于如何处理数据,尤其是引入了掩码语言建模(Masked Language Modeling)和下一句预测(Next Sentence Prediction)。在这篇文章中,我们将从零开始构建BERT,并演示其训练和推理过程。

1. 数据预处理

        我们将使用一个简单的语料库,首先对文本进行预处理,包括将文本转换为小写、去掉标点符号等。

import spacy# 加载简单的文本数据
with open("data/wiki_king.txt", "r") as f:raw_text = f.read()
nlp = spacy.load("en_core_web_sm")
doc = nlp(raw_text)
sentences = list(doc.sents)# 处理文本:转为小写并去除标点符号
text = [x.text.lower() for x in sentences]
text = [re.sub("[.,!?\\-]", '', x) for x in text]
print(text)

        我们可以看到,所有的句子都已经被处理成小写,并且去除了标点符号。

2. 词汇表生成

        在生成词汇表之前,我们首先需要将文本拆分为单词。我们为每个单词生成唯一的ID,并为特殊标记(如[PAD][CLS][SEP][MASK])预留ID。

# 生成词汇表
word_list = list(set(" ".join(text).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # 特殊标记# 为每个单词生成ID
for i, w in enumerate(word_list):word2id[w] = i + 4  # 特殊标记占用了0-3
id2word = {i: w for w, i in word2id.items()}
vocab_size = len(word2id)# 将文本转换为ID序列
token_list = [[word2id[word] for word in sentence.split()] for sentence in text]
print(token_list)

3. 数据加载器

        BERT模型需要处理两种嵌入:Token嵌入Segment嵌入,并且要对输入句子进行随机掩码处理。我们将实现一个生成批处理数据的函数,该函数包含以下步骤:

  1. Token嵌入:在句子开头添加[CLS]标记,两个句子之间添加[SEP]标记。
  2. Segment嵌入:用0和1来区分两个句子。
  3. 掩码语言建模:随机掩盖15%的单词,其中80%替换为[MASK]标记。
  4. 填充:将所有序列填充到相同长度。
batch_size = 6
max_mask = 5
max_len = 1000def make_batch():batch = []positive = negative = 0while positive != batch_size/2 or negative != batch_size/2:tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]# 1. Token嵌入input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]# 2. Segment嵌入segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)# 3. 掩码语言建模n_pred = min(max_mask, max(1, int(len(input_ids) * 0.15)))cand_maked_pos = [i for i, token in enumerate(input_ids) if token not in (word2id['[CLS]'], word2id['[SEP]'])]shuffle(cand_maked_pos)masked_tokens, masked_pos = [], []for pos in cand_maked_pos[:n_pred]:masked_pos.append(pos)masked_tokens.append(input_ids[pos])if random() < 0.1:input_ids[pos] = randint(0, vocab_size - 1)elif random() < 0.9:input_ids[pos] = word2id['[MASK]']# 4. 填充n_pad = max_len - len(input_ids)input_ids.extend([0] * n_pad)segment_ids.extend([0] * n_pad)if max_mask > n_pred:n_pad = max_mask - n_predmasked_tokens.extend([0] * n_pad)masked_pos.extend([0] * n_pad)if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])positive += 1elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])negative += 1return batchbatch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))
print(input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext.shape)

4. BERT模型实现

4.1 嵌入层

        在BERT模型中,嵌入层负责将输入的Token、位置和句子段嵌入整合在一起。我们使用LayerNorm来标准化输出。

class Embedding(nn.Module):def __init__(self):super(Embedding, self).__init__()self.tok_embed = nn.Embedding(vocab_size, d_model)self.pos_embed = nn.Embedding(max_len, d_model)self.seg_embed = nn.Embedding(n_segments, d_model)self.norm = nn.LayerNorm(d_model)def forward(self, x, seg):pos = torch.arange(x.size(1), dtype=torch.long).unsqueeze(0).expand_as(x)embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)return self.norm(embedding)

4.2 多头注意力机制

        BERT的多头注意力机制允许模型在不同的注意力头上并行关注不同部分的输入。

class MultiHeadAttention(nn.Module):def __init__(self):super(MultiHeadAttention, self).__init__()self.W_Q = nn.Linear(d_model, d_k * n_heads)self.W_K = nn.Linear(d_model, d_k * n_heads)self.W_V = nn.Linear(d_model, d_v * n_heads)def forward(self, Q, K, V, attn_mask):q_s = self.W_Q(Q).view(Q.size(0), -1, n_heads, d_k).transpose(1, 2)k_s = self.W_K(K).view(K.size(0), -1, n_heads, d_k).transpose(1, 2)v_s = self.W_V(V).view(V.size(0), -1, n_heads, d_v).transpose(1, 2)attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)context = context.transpose(1, 2).contiguous().view(Q.size(0), -1, n_heads * d_v)return nn.LayerNorm(d_model)(nn.Linear(n_heads * d_v, d_model)(context) + Q), attn

4.3 BERT模型

        最后我们定义完整的BERT模型,其中包括嵌入层、编码层以及用于掩码语言模型和下一句预测的分类器。

class BERT(nn.Module):def __init__(self):super(BERT, self).__init__()self.embedding = Embedding()self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])self.fc = nn.Linear(d_model, d_model)self.activ = nn.Tanh()self.classifier = nn.Linear(d_model, 2)embed_weight = self.embedding.tok_embed.weightself.decoder = nn.Linear(embed_weight.size(1), vocab_size, bias=False)self.decoder.weight = embed_weightself.decoder_bias = nn.Parameter(torch.zeros(vocab_size))def forward(self, input_ids, segment_ids, masked_pos):output = self.embedding(input_ids, segment_ids)enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)for layer in self.layers:output, _ = layer(output, enc_self_attn_mask)h_pooled = self.activ(self.fc(output[:, 0]))logits_nsp = self.classifier(h_pooled)masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))h_masked = torch.gather(output, 1, masked_pos)logits_lm = self.decoder(h_masked) + self.decoder_biasreturn logits_lm, logits_nsp

5. 模型训练

        我们将BERT模型进行训练,使用交叉熵损失函数对掩码语言建模和下一句预测进行优化。

num_epoch = 500
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(num_epoch):optimizer.zero_grad()logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens).mean()loss_nsp = criterion(logits_nsp, isNext)loss = loss_lm + loss_nspif epoch % 100 == 0:print(f'Epoch: {epoch}, Loss: {loss.item():.6f}')loss.backward()optimizer.step()

6. 推理

        最后,我们演示如何使用训练好的BERT模型进行推理。

logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('预测的掩码单词:', [id2word[pos] for pos in logits_lm])
logits_nsp = logits_nsp.data.max(1)[1][0].data.numpy()
print('是否为下一句:', '是' if logits_nsp else '否')

结语

        在本篇文章中,我们详细探讨了BERT模型的构建与实现,尤其是在数据处理阶段的独特之处,如掩码语言模型(Masked Language Model, MLM)和下一句预测(Next Sentence Prediction, NSP)的结合。通过逐步解析BERT的编码器结构、注意力机制以及多层自注意力网络,我们展示了如何在实际项目中应用BERT模型进行预训练任务。

        BERT的强大在于其双向编码的特性,使得模型能够充分利用上下文信息,特别适合解决文本分类、句子配对等复杂的自然语言理解任务。在这个实现中,我们使用了一个简单的数据集来演示BERT的核心功能。虽然实际性能受到数据量的限制,但通过扩展到更大的数据集和更复杂的任务,BERT的能力会得到充分发挥。

        在下一篇文章中,我们将探讨Pruning技术,它是一种模型压缩方法,可以通过减少模型中的冗余参数来提高计算效率和内存使用率。这对于部署深度学习模型至关重要,尤其是在资源有限的设备上。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1552149.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

数据中心交换机与普通交换机之间的区别到底在哪里?

号主&#xff1a;老杨丨11年资深网络工程师&#xff0c;更多网工提升干货&#xff0c;请关注公众号&#xff1a;网络工程师俱乐部 上午好&#xff0c;我的网工朋友。 数据中心交换被设计用来满足数据中心特有的高性能、高可靠性和可扩展性需求。 与此同时&#xff0c;普通交换机…

全面提升MySQL性能:从硬件到配置再到代码的最佳实践

MySQL 是全球最流行的开源关系型数据库管理系统之一&#xff0c;广泛应用于各种规模的应用程序中。随着应用规模的增长&#xff0c;数据库的性能优化成为提升系统整体性能的关键因素。本文将从多个角度探讨如何对MySQL进行性能优化&#xff0c;帮助开发者和DBA解决实际问题&…

免费 Oracle 各版本 离线帮助使用和介绍

文章目录 Oracle 各版本 离线帮助使用和介绍概要在线帮助下载离线文档包&#xff1a;解压离线文档&#xff1a;访问离线文档&#xff1a;导航使用&#xff1a;目录介绍Install and Upgrade&#xff08;安装和升级&#xff09;&#xff1a;Administration&#xff08;管理&#…

Android 13.0 系统wifi列表显示已连接但无法访问网络问题解决

1.前言 在13.0的系统rom产品定制化开发中,在wifi模块也很重要,但是在某些情况下对于一些wifi连接成功后,确显示已连接成功,但是无法访问互联网 的情况,所以实际上这时可以正常上网的,就是显示的不正常,所以就需要分析连接流程然后解决问题 如图所示: 2.系统wifi列表显示…

linux文件编程_进程

1. 进程相关概念 面试中关于进程&#xff0c;应该会问的的几个问题&#xff1a; 1.1. 什么是程序&#xff0c;什么是进程&#xff0c;有什么区别&#xff1f; 程序是静态的概念&#xff0c;比如&#xff1a; 磁盘中生成的a.out文件&#xff0c;就叫做&#xff1a;程序进程是…

【Python报错已解决】 Encountered error while trying to install package.> lxml

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 专栏介绍 在软件开发和日常使用中&#xff0c;BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

掌握 JVM 垃圾收集线程:简化 VM 选项

垃圾收集阶段对于任何 Java 应用程序都至关重要。主要目标是保持高吞吐量和低延迟之间的平衡。通过配置垃圾收集器&#xff0c;我们可以提高性能&#xff0c;或者至少推动应用程序朝着特定的方向发展。 垃圾收集周期越短越好。因此&#xff0c;分配给垃圾收集器的资源越多&…

RS485串口通信:【图文详讲】

RS485&#xff0c;RS的意义为Recommended Standard的缩写&#xff0c;也就是推荐标准&#xff0c;是一种常用的半双工-异步-串行通信总线。半双工的意思就是两者通信时&#xff0c;同一时刻&#xff0c;只能由其中一方发送&#xff0c;另一方只能接收&#xff0c;不可以同时收发…

Java 每日一刊(第18期):集合

文章目录 前言1. Java 集合框架概述1.1 Java 集合框架的定义和意义1.2 Java 集合框架的历史演进1.3 集合框架的基本组成部分1.4 Java 集合的优势1.5 Java 集合与数组的区别与关系 2. Java 集合框架的核心接口2.1 Collection 接口2.2 List 接口2.3 Set 接口2.4 Queue 接口2.5 Ma…

共享单车轨迹数据分析:以厦门市共享单车数据为例(九)

副标题&#xff1a;基于站点800m范围内评价指标探究——以吕厝站为例 上篇文章我们以厦门市为例&#xff0c;来通过POI和优劣解距离法&#xff08;TOPSIS&#xff09;来研究厦门岛内以800m作为辐射范围的地铁站哪些地铁站发展的最好&#xff0c;根据综合得分指数可以知道&…

【Linux】【操作】Linux操作集锦系列之七——Linux环境下如何查看CPU使用情况(利用率等)

&#x1f41a;作者简介&#xff1a;花神庙码农&#xff08;专注于Linux、WLAN、TCP/IP、Python等技术方向&#xff09;&#x1f433;博客主页&#xff1a;花神庙码农 &#xff0c;地址&#xff1a;https://blog.csdn.net/qxhgd&#x1f310;系列专栏&#xff1a;Linux技术&…

AutoGen实现多代理-Planning_and_Stock_Report_Generation(六)

1. 案例背景 本节内容是构建Agent组&#xff0c;通过广播模式&#xff0c;实现管理者对agent工作流的管理。本实验链接&#xff1a;传送门 2. 代码实践 2.1 环境设置 llm_config{"model": "gpt-4-turbo"}# 工作任务描述 task "Write a blogpost a…

Cyberduck网络鸭-访问远程文件客户端新选择

Cyberduck 是一款适用于 macOS 和 Windows 的自由文件传输客户端。适用于 Linux、macOS 和 Windows 的命令行界面 (CLI)。核心库用于Mountain Duck。 官网&#xff1a;https://cyberduck.io/download/ 开源地址&#xff1a; https://cyberduck.io/download/ 支持协议很多&…

国庆同欢,祖国昌盛!肌肉纤维启发,水凝胶如何重构聚合物

在这个国庆佳节&#xff0c;我们共同感受祖国的繁荣昌盛&#xff0c;同时也迎来了知识的探索之旅。今天来了解聚合物架构的重构的研究——《Hydrogel‐Reactive‐Microenvironment Powering Reconfiguration of Polymer Architectures》发表于《Advanced Science》。材料科学不…

消费电子制造企业如何使用SAP系统提升运营效率与竞争力

在当今这个日新月异的消费电子市场中&#xff0c;企业面临着快速变化的需求、激烈的竞争以及不断攀升的成本压力。为了在这场竞赛中脱颖而出&#xff0c;消费电子制造企业纷纷寻求数字化转型的突破点&#xff0c;其中&#xff0c;SAP系统作为业界领先的企业资源规划(ERP)解决方…

怀孕之天赋共享:其实人身体没变,完全是天赋共享

关于怀孕天赋共享&#xff0c;有人说&#xff0c;是不是怀孕导致身体变化&#xff1f; 并没有。下面这个就是案例。你总不能说&#xff0c;小孩生下来身体立即改变吧&#xff1f;

World of Warcraft [CLASSIC] Engineering 421-440

工程学421-440 World of Warcraft [CLASSIC] Engineering 335-420_魔兽世界宗师级工程学需要多少点-CSDN博客 【萨隆邪铁锭】421-425 学习新技能&#xff0c;其他都不划算&#xff0c;只能做太阳瞄准镜 【太阳瞄准镜】426、427、428、429 【随身邮箱】430 这个基本要做的&am…

基于SSM的农产品仓库管理系统【附源码】

基于SSM的农产品仓库管理系统&#xff08;源码L文说明文档&#xff09; 目录 4 系统设计 4.1 系统概要设计 4.2 系统功能结构设计 4.3 数据库设计 4.3.1 数据库E-R图设计 4.3.2 数据库表结构设计 5 系统实现 5.1 管理员功能介绍 5.1.1 用户管…

ios内购支付-支付宝APP支付提现

文章目录 前言一、IOS内购支付&#xff08;ios订单生成自己写逻辑即可&#xff09;1.支付回调票据校验controller1.支付回调票据校验server 二、安卓APP支付宝支付1.生成订单返回支付宝字符串&#xff08;用于app拉起支付宝&#xff0c;这里用的是证书模式&#xff09;2.生成订…

Java 死锁及避免讲解和案例示范

在大型分布式系统中&#xff0c;死锁是一种常见但难以排查的并发问题。特别是在 Java 领域&#xff0c;死锁问题可能导致系统崩溃或卡顿。本文将以电商交易系统为例&#xff0c;详细讲解如何识别和避免 Java 程序中的死锁问题&#xff0c;确保系统高效运行。 1. 什么是死锁&am…