政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练

目录

简介

导入

基本批量大小和学习率

计算按比例分配的批量大小和学习率


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 KerasNLP 和 tf.distribute 进行数据并行训练。

简介


分布式训练是一种在多台设备或机器上同时训练深度学习模型的技术。它有助于缩短训练时间,并允许使用更多数据训练更大的模型。KerasNLP 是一个为自然语言处理任务(包括分布式训练)提供工具和实用程序的库。

在本文中,我们将使用 KerasNLP 在 wikitext-2 数据集(维基百科文章的 200 万字数据集)上训练基于 BERT 的屏蔽语言模型 (MLM)。MLM 任务包括预测句子中的屏蔽词,这有助于模型学习单词的上下文表征。

本指南侧重于数据并行性,尤其是同步数据并行性,即每个加速器(GPU 或 TPU)都拥有一个完整的模型副本,并查看不同批次的部分输入数据。部分梯度在每个设备上计算、汇总,并用于计算全局梯度更新。

具体来说,本文将教您如何在以下两种设置中使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对代码做最小的改动:

—— 在一台机器上安装多个 GPU(通常为 2 至 8 个)(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的设置。
—— 在由多台机器组成的集群上,每台机器安装一个或多个 GPU(多设备分布式训练)。这是大规模行业工作流程的良好设置,例如在 20-100 个 GPU 上对十亿字数据集进行高分辨率文本摘要模型训练。

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
import keras_nlp

在开始任何训练之前,让我们配置一下我们的单 GPU,使其显示为两个逻辑设备。

在使用两个或更多物理 GPU 进行训练时,这完全没有必要。这只是在默认 colab GPU 运行时(只有一个 GPU 可用)上显示真实分布式训练的一个技巧。

!nvidia-smi --query-gpu=memory.total --format=csv,noheader
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(physical_devices[0],[tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),],
)logical_devices = tf.config.list_logical_devices("GPU")
logical_devicesEPOCHS = 3
24576 MiB

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

—— 实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
—— 使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常情况下,这意味着在分发作用域内创建和编译模型。
—— 像往常一样通过 fit() 训练模型。

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2

基本批量大小和学习率

base_batch_size = 32
base_learning_rate = 1e-4

计算按比例分配的批量大小和学习率

scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

现在,我们需要下载并预处理 wikitext-2 数据集。该数据集将用于预训练 BERT 模型。我们将过滤掉短行,以确保数据有足够的语境用于训练。

keras.utils.get_file(origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")# Load wikitext-103 and filter out short lines.
wiki_train_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.train.tokens",).filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)

在上述代码中,我们下载并提取了 wikitext-2 数据集。然后,我们定义了三个数据集:wiki_train_ds、wiki_val_ds 和 wiki_test_ds。我们对这些数据集进行了过滤,以去除短行,并对其进行批处理,以提高训练效率。

在 NLP 训练/调整中,使用衰减学习率是一种常见的做法。在这里,我们将使用多项式衰减时间表(PolynomialDecay schedule)。

total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=scaled_learning_rate,decay_steps=total_training_steps,end_learning_rate=0.0,
)class PrintLR(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):print(f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}")

我们还要回调 TensorBoard,这样就能在本教程后半部分训练模型时可视化不同的指标。我们将所有回调放在一起,如下所示:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs"),PrintLR(),
]print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

准备好数据集后,我们现在要在 strategy.scope() 中初始化并编译模型和优化器:

with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")# This line just sets pooled_dense layer as non-trainiable, we do this to avoid# warnings of this layer being unusedmodel_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = Falsemodel_dist.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],jit_compile=False,)model_dist.fit(wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 - sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 - val_sparse_categorical_accuracy: 0.3485
Epoch 2/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 - sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 - sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 - val_sparse_categorical_accuracy: 0.4006
Epoch 3/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 - sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 - sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.4230

根据范围拟合模型后,我们对其进行正常评估!

model_dist.evaluate(wiki_test_ds)
 29/29 ━━━━━━━━━━━━━━━━━━━━ 3s 60ms/step - loss: 1.9197 - sparse_categorical_accuracy: 0.8527[0.9470901489257812, 0.4373602867126465]

对于跨多台计算机的分布式训练(而不是只利用单台计算机上的多个设备进行训练),您可以使用两种分布式策略:MultiWorkerMirroredStrategy 和 ParameterServerStrategy:

—— tf.distribution.MultiWorkerMirroredStrategy(多工作站策略)实现了一种 CPU/GPU 多工作站同步解决方案,可与 Keras 风格的模型构建和训练循环配合使用,并使用跨副本的梯度同步还原。
—— tf.distribution.experimental.ParameterServerStrategy(参数服务器策略)实现了一种异步 CPU/GPU 多工作站解决方案,其中参数存储在参数服务器上,工作站异步更新梯度到参数服务器。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1419992.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

使用html和css实现个人简历表单的制作

根据下列要求,做出下图所示的个人简历(表单) 表单要求 Ⅰ、表格整体的边框为1像素,单元格间距为0,表格中前六列列宽均为100像素,第七列 为200像素,表格整体在页面上居中显示; Ⅱ、前…

【Unity Shader入门精要 第6章】基础光照(一)

1. 什么是光照模型 光照原理 在真实世界中,我们能够看到物体,是由于眼睛接收到了来自观察目标的光。这里面包括两种情况:一部分是观察目标本身发出的光(自发光)直接进入我们的眼睛,另一部分是其他物体&am…

基于Django实现的校园疫情监控平台

基于Django实现的校园疫情监控平台 开发语言:Python 数据库:MySQL所用到的知识:Django框架工具:pycharm、Navicat、Maven 系统功能实现 登录注册功能 用户在没有登录自己的用户名之前只能浏览本网站的首页,想要使用其他功能都会…

组织机构树形列表实现

源码地址:https://www.lanzouw.com/itjDc1ydraof 本来上传了源码,但是发现只能VIP才能下载,所以重新上传到蓝奏云上了,链接如下: 先看下效果图: 可以自己写HTML来自定义每一项的内容显示,包括…

哈希算法在区块链中的应用

哈希算法是区块链技术的核心组件之一,它确保了区块链数据的不可篡改性和安全性。在本文中,我们将探讨哈希算法的基本原理,以及它在区块链中的具体应用。 哈希算法的基本原理 哈希算法是一种数学函数,它接收输入(或“消…

excel转pdf的java实现

一、实现原理 采用java调用vbs脚本调用office应用把excel转成pdf。 支持文件格式:xlsx,xls,csv 二、前期准备 1、安装office软件 2、准备vbs脚本文件,放到C:\excel2pdf_script\目录下。(本文只用2个文件) 三、VBS转换脚本 1…

有边数限制的最短路

文章目录 题目 有边数限制的最短路算法分析1、问题:为什么Dijkstra不能使用在含负权的图中?dijkstra详细步骤2、什么是bellman - ford算法?3、bellman - ford算法的具体步骤4、在下面代码中,是否能到达n号点的判断中需要进行if(di…

vue2 八大组件通信,父子通信,跨层级通信,事件总线,vuex等

文章目录 什么是组件通信?父子通信流程propsProps 定义Props 作用特点数组写法对象写法(props校验)简写只验证数据类型:完整写法,完整的验证: props父向子传值用props父传子在子组件中修改props $emit子向父…

vue3点击添加小狗图片,vue3拆分脚本

我悄悄蒙上你的眼睛 模板和样式 <template><div class"XueXi_Hooks"><img v-for"(dog, index) in dog1List" :src"dog" :key"index" /><button click"addDog1">点我添加狗1</button><hr …

圆柱齿轮的旋向如何判断?

上期出了个题&#xff0c;给了两个内齿轮&#xff0c;请大家来判断他们的旋向&#xff0c;看到了有不少小伙伴评论给出了自己的答案&#xff0c;正确和错误差不多各半吧&#xff0c;错的占比要大一些。这期咱们就好好聊一聊这个问题。 外齿轮的旋向大家貌似判断都没什么问题&a…

Hive行列转换应用与实现

Hive行列转换应用与实现 1.多行转多列 问题引入 解决方法 2.多行转单列 问题引入 解决方法 3.多列转多行 问题引入 解决方法 4.单列转多行

信息系统项目管理师0102:可行性研究的内容(7项目立项管理—7.2项目可行性研究—7.2.1可行性研究的内容)

点击查看专栏目录 文章目录 7.2项目可行性研究7.2.1可行性研究的内容1.技术可行性分析2.经济可行性分析3.社会效益可行性分析4.运行环境可行性分析5.其他方面的可行性分析记忆要点总结7.2项目可行性研究 可行性研究是在项目建议书被批准后,从技术、经济、社会和人员等方面的条…

200+套AxureBi可视化大数据大屏看板原型设计方案

产品名称&#xff1a;200套AxureBi可视化大屏看板原型设计方案 模板数量&#xff1a;200套平均单价0.46元&#xff08;持续增加中~平均每2周一更&#xff09; 软件版本: Axure 8,Axure 9,Axure 10&#xff08;兼容&#xff09; 作品类型: BI数据大屏可视化Axure原型 文件类型: …

多线程-线程安全

目录 线程安全问题 加锁(synchronized) synchronized 使用方法 synchronized的其他使用方法 synchronized 重要特性(可重入的) 死锁的问题 对 2> 提出问题 对 3> 提出问题 解决死锁 对 2> 进行解答 对4> 进行解答 volatile 关键字 wait 和 notify (重要…

SpringBoot中使用MongoDB

目录 搭建实体类 基本的增删改查操作 分页查询 使用MongoTemplate实现复杂的功能 引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-mongodb</artifactId> </dependency> 在ap…

CentOS 8.5 安装配置 Tinyproxy 轻量代理服务器 Windows10 系统设置http代理 详细教程

1 下载 下载地址 2 上传服务器并解压 tar zxvf tinyproxy-1.11.2.tar.gz 3 安装配置 #安装依赖软件 yum install automake cd tinyproxy-1.11.2/ #生成configure ./autogen.sh # ./configure --prefix/usr/local/tinyproxy make make install 4 配置环境 vim /etc/prof…

【教程】最新MySQL8.3.0社区版安装指南(超详细)

写在前面&#xff1a; 如果文章对你有帮助&#xff0c;记得点赞关注加收藏一波&#xff0c;利于以后需要的时候复习&#xff0c;多谢支持&#xff01; 文章目录 一、下载安装包二、解压安装包三、设置配置文件四、配置系统环境五、初始化操作 此次安装的版本为MySQL社区版&…

【JVM】Class文件的格式

目录 概述 Class文件的格式 概述 Class文件是JVM的输入&#xff0c;Java虚拟机规范中定义了Class文件的结构。Class文件是JVM实现平台无关、技术无关的基础。 1:Class文件是一组以8字节为单位的字节流&#xff0c;各个数据项目按顺序紧凑排列 2:对于占用空间大于8字节的数据…

实验室信息管理系统主要解决哪些问题,能帮实验室从哪些方面提升效率?

实验室信息管理系统&#xff08;LIMS&#xff09;是一种全面精益化管理工具&#xff0c;它对实验室的人、机、料、法、环进行精确管理&#xff0c;使监测业务高效、准确、方便&#xff0c;确保实验室的运行效率和数据安全性得到极大的提升。通过LIMS&#xff0c;实验室能够实现…

Android Studio连接MySQL8.0

【序言】 移动平台这个课程要做一个app的课设&#xff0c;我打算后期增加功能改成毕设&#xff0c;就想要使用MySQL来作为数据库&#xff0c;相对于SQLlite来说&#xff0c;我更熟悉MySQL一点。 【遇到的问题】 一直无法连接上数据库&#xff0c;开始的时候查了很多资料&#…