分布式训练:(Pytorch)

分布式训练是将机器学习模型的训练过程分散到多个计算节点或设备上,以提高训练速度和效率,尤其是在处理大规模数据和模型时。分布式训练主要分为数据并行模型并行两种主要策略:

1. 数据并行 (Data Parallelism)

数据并行是最常见的分布式训练方式。在这种方法中,模型副本会被复制到多个计算设备上,每个设备处理不同的批次(batch)数据。

工作流程:
  • 每个设备上都有一个完整的模型副本。
  • 数据集被分割成多个部分(mini-batches),每个设备处理其中一部分。
  • 每个设备独立计算模型的前向传播和反向传播,计算出梯度。
  • 通过某种方式(如梯度聚合),将所有设备的梯度平均化,并更新全局模型参数。
  • 同步方式可分为同步训练和异步训练:
    • 同步训练:所有设备都在同一个时刻更新模型参数。
    • 异步训练:各设备独立更新参数,可能导致一些参数不一致。
# Replicate module to devices in device_ids
replicas = nn.parallel.replicate(module, device_ids)
# Distribute input to devices in device_ids
inputs = nn.parallel.scatter(input, device_ids)
# Apply the models to corresponding inputs
outputs = nn.parallel.parallel_apply(replicas, inputs)
# Gather result from all devices to output_device
result = nn.parallel.gather(outputs, output_device)
优点:
  • 易于实现,特别是在GPU集群或云端平台中。
  • 可以在大规模数据集上显著加快训练过程。
缺点:
  • 通信开销较大,特别是在梯度同步阶段,可能会成为训练速度的瓶颈。
  • 对大模型的扩展性有限,因为每个设备都需要存储完整的模型。

2. 模型并行 (Model Parallelism)

模型并行将一个大型模型拆分到多个设备上,以便更好地利用计算资源,尤其适用于内存消耗较大的模型。

工作流程:
  • 模型被拆分成多个部分,每个设备负责模型的一个子集。
  • 输入数据在各设备间传递,完成前向传播和反向传播。
  • 各设备独立计算梯度并更新自己负责的模型参数。
优点:
  • 适合超大规模模型,尤其是单个设备无法存储整个模型的情况。
  • 内存使用效率较高。
缺点:
  • 由于模型的不同部分在不同设备上进行计算,存在大量的通信开销,尤其是在前向传播和反向传播时需要设备间频繁交互。
  • 难以实现模型的负载均衡,部分设备可能成为性能瓶颈。

常用的分布式训练框架

  • TensorFlow:支持多设备、多机器的分布式训练,通过 tf.distribute.Strategy 轻松实现。
  • PyTorch:通过 torch.distributed 提供原生支持,还支持基于 Horovod 等第三方工具的分布式训练。
  • Horovod:Uber 开源的分布式深度学习库,支持 TensorFlow、Keras、PyTorch 等。

关键挑战

  • 同步和通信开销:在数据并行训练中,梯度的同步可能成为瓶颈。
  • 负载均衡:在模型并行训练中,确保各设备之间的负载均衡非常重要,以避免性能瓶颈。
  • 容错性:分布式训练中节点故障可能导致训练过程中断,需要具备一定的容错机制。

常用的 API 有两个:

  • torch.nn.DataParallel(DP)
  • torch.nn.DistributedDataParallel(DDP)

torch.nn.DataParallel(简称 DP)是 PyTorch 提供的一个简单的并行化工具,主要用于在多个 GPU 上进行数据并行训练。DataParallel 通过将输入数据批次(batch)切分成多个小批次,并将其分发到多个 GPU 上,进行并行处理。它会自动处理梯度的同步和模型参数的更新。

torch.nn.DataParallel 的工作机制

  1. 模型复制DataParallel 会将模型复制到多个 GPU 上,每个 GPU 上有一个模型副本。
  2. 数据分割:输入数据会被划分成多个小批次(mini-batches),并分别分发给各个 GPU。
  3. 并行执行:每个 GPU 独立进行前向传播和反向传播,计算梯度。
  4. 梯度汇总:主设备(默认是 cuda:0)会收集所有 GPU 计算出的梯度,并将它们平均化,更新模型的全局参数。

使用 torch.nn.DataParallel

使用 DataParallel 非常简单,通常只需要将模型用 DataParallel 包裹,然后像普通模型一样使用即可。

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 5)def forward(self, x):return self.fc(x)# 初始化模型和数据
model = SimpleModel()# 将模型并行化
if torch.cuda.device_count() > 1:print("Using", torch.cuda.device_count(), "GPUs")model = nn.DataParallel(model)model = model.cuda()# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 模拟输入数据
inputs = torch.randn(32, 10).cuda()  # 一个 32 样本的 batch,每个样本 10 个特征
targets = torch.randn(32, 5).cuda()  # 对应的目标输出# 前向传播
outputs = model(inputs)# 计算损失
loss = criterion(outputs, targets)# 反向传播
optimizer.zero_grad()
loss.backward()# 更新模型参数
optimizer.step()

DistributedDataParallel (简称 DDP) 是 PyTorch 用于分布式训练的高级并行化工具,它的效率和灵活性比 DataParallel 更高,特别适合在多个 GPU 甚至跨多个节点(机器)上进行分布式训练。与 DataParallel 不同,DDP 在每个设备(GPU)上独立处理模型的前向传播和反向传播,并且避免了主设备的瓶颈问题。

DistributedDataParallel 的工作原理

  1. 模型的分发:与 DataParallel 类似,DDP 会在每个 GPU 上保留一份模型副本。但与 DataParallel 不同的是,DDP 不需要将数据集中在主设备上,而是让每个 GPU 独立完成自己的工作。
  2. 前向和反向传播:每个 GPU 上的模型执行前向传播和反向传播,并计算梯度。
  3. 梯度同步:每个设备上计算的梯度通过 all-reduce 操作在所有设备之间同步,确保所有模型副本的梯度相同。这个过程是并行进行的,不会像 DataParallel 那样集中在主设备上,因此通信效率更高。
  4. 参数更新:每个设备独立地应用梯度更新全局模型参数。

DistributedDataParallel 的优点

  • 高效的通信和同步:梯度的同步是在所有设备之间并行进行的,避免了主设备成为通信瓶颈的问题,因此在多 GPU 或跨节点时表现更加优异。
  • 可扩展性强DDP 支持跨多台机器的训练,适合超大规模模型或需要跨节点的分布式训练。
  • 无锁设计DDP 实现了无锁的梯度同步,不会因锁机制造成性能损失。

DistributedDataParallel 的使用

DataParallel 类似,DDP 也需要对模型进行包装,但它需要更多的设置,特别是在多机环境下,还需要配置通信后端。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 初始化分布式环境
def setup(rank, world_size):dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)# 销毁分布式环境
def cleanup():dist.destroy_process_group()# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 5)def forward(self, x):return self.fc(x)# 初始化模型、优化器和数据
def main(rank, world_size):setup(rank, world_size)model = SimpleModel().cuda(rank)ddp_model = DDP(model, device_ids=[rank])criterion = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)# 模拟输入数据inputs = torch.randn(32, 10).cuda(rank)targets = torch.randn(32, 5).cuda(rank)# 前向传播outputs = ddp_model(inputs)loss = criterion(outputs, targets)# 反向传播optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()cleanup()# 假设有两个GPU,可以这样启动分布式训练
if __name__ == "__main__":world_size = 2  # GPU数torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size, join=True)
特性DataParallel (DP) DistributedDataParallel (DDP)
通信模式主设备负责梯度同步所有设备间并行同步梯度
性能通信开销大,主设备瓶颈通信开销小,性能更高
可扩展性适用于单机多 GPU适用于单机或多机多 GPU
使用场景小规模并行大规模或跨节点分布式训练

2. 并行数据加载

在深度学习任务中,数据加载通常是训练过程中的一个瓶颈,特别是当数据量很大时。使用多个进程来并行加载数据,并将数据从可分页内存(虚拟内存)转移到固定内存(GPU 内存)可以显著提高训练效率。

工作流程

  1. 数据加载

    • 使用多个进程并行从磁盘读取数据。每个进程负责加载不同的数据批次,减少了磁盘 I/O 操作的等待时间。
  2. 生产者-消费者模式

    • 数据加载进程(生产者)将读取的数据批次放入队列中,而主线程(消费者)从队列中取出数据批次进行训练。这样可以在数据加载和模型训练过程中实现并行化,减少数据加载对训练速度的影响。
  3. 固定内存的使用

    • 将数据从主机的可分页内存转移到固定内存。数据被加载到固定内存中后,转移到 GPU 的速度会更快,因为固定内存中的数据可以快速传输。

参数解释

  1. num_workers

    • 这个参数指定了数据加载的进程数量。将 num_workers 设置为大于 0 的值可以让 DataLoader 使用多个子进程来并行加载数据。
    • 例如,num_workers=4 表示使用 4 个进程来加载数据。这可以显著提高数据加载速度,因为多个进程可以同时从磁盘读取不同的数据批次。
  2. pin_memory

    • 这个参数用于将数据从主机内存(CPU 内存)固定到页面锁定内存(pinned memory)。固定内存可以让数据传输到 GPU 更加高效。
    • pin_memory=True 时,DataLoader 会将数据从可分页的内存(虚拟内存)传输到固定内存中,这样在将数据转移到 GPU 时,数据传输速度会更快,因为固定内存可以避免页面交换的开销。

总结

  • 数据加载:使用多个进程来并行加载和预处理数据,通过流水线处理减少数据加载的延迟。
  • 数据传输:利用 CUDA 流优化从固定内存到 GPU 的数据传输。
  • 数据并行性:使用数据并行和 NCCL 等通信库实现高效的梯度同步和模型参数更新,优化训练过程。

这种方法结合了数据加载、数据传输和数据并行处理的优化,能够显著提升深度学习模型的训练效率和速度。

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as npclass CustomDataset(Dataset):def __init__(self, size):self.data = np.random.rand(size, 3, 224, 224).astype(np.float32)self.labels = np.random.randint(0, 2, size).astype(np.int64)def __len__(self):return len(self.data)def __getitem__(self, idx):return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])dataset = CustomDataset(size=10000)
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,      # 使用 4 个子进程加载数据pin_memory=True     # 将数据转移到固定内存
)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")for inputs, labels in dataloader:inputs, labels = inputs.to(device), labels.to(device)# 模型训练代码# ...

 参考文章:

Pytorch 分布式训练(DP/DDP)_pytorch分布式训练-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/ytusdc/article/details/122091284?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522CC589E02-BBE1-4F15-BDC0-CA76EBF6C160%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=CC589E02-BBE1-4F15-BDC0-CA76EBF6C160&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_positive~default-1-122091284-null-null.142^v100^control&utm_term=%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83&spm=1018.2226.3001.4187

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1536681.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

分布式中间件-分布式代理框架Codis和Twemproxy

文章目录 Codis框架架构图 Twemproxy框架Codis和Twemproxy对比设计目标功能特性使用场景结论 Codis框架 Codis是一个开源的分布式内存键值存储系统,它基于Redis并且提供了一个分布式的解决方案来扩展单一Redis实例的能力。Codis项目由豌豆荚团队开发,并…

订单防重复提交:token 发放以及校验

订单防重复提交:token 发放以及校验 1. 基于Token校验避免订单重复提交 1. 基于Token校验避免订单重复提交 在很多秒杀场景中,用户为了能下单成功,会频繁的点击下单按钮,这时候如果没有做好控制的话,就可能会给一个用…

春秋云境之CVE-2022-30887

一.靶场环境 1.下载靶场环境 根据题目提示,此靶场存在文件上传漏洞。 2.启动靶场环境 我们可以看到是一个登录页面,我们尝试进行登录 二.登录页面 1.尝试进行登录 我们发现用户名必须是邮箱,那么弱口令肯定不行,我们可以看到…

Qt集成Direct2D绘制,实现离屏渲染

没搜到关于Qt中使用Direct2D的方式&#xff0c;想了个办法&#xff0c;在此做个记录。 需要引入这两个库&#xff1a; 代码&#xff1a; #pragma once #include <QWidget> #include <QImage> #include <QPainter> #include <QMouseEvent>#include &q…

【23-24年】年度总结与迎新引荐

文章目录 相关连接前言1 忙碌的备研与本科毕设2 暑期阿里之旅3 团队荣誉与迎新引荐4 项目合作意向 相关连接 个人博客&#xff1a;issey的博客 - 愿无岁月可回首 前言 自从2023年4月更新了两篇关于NLP的文章后&#xff0c;我便消失了一年半的时间。如今&#xff0c;随着学业…

SpringBoot 图书管理系统

文章目录 一、删除图书二、批量删除三、强制登录3.1 不使用拦截器3.2 使用拦截器 四、更新图书 一、删除图书 并不使用delete语句&#xff1a; 原因&#xff1a;企业开发中&#xff0c;因为数据就意味着金钱&#xff0c;所以我们不会使用delete去删除&#xff08;delete删除是…

基于SpringBoot的人事管理系统【附源码】

基于SpringBoot的人事管理系统&#xff08;源码L文说明文档&#xff09; 目录 4 系统设计 4.1 系统概述 4.2系统功能结构设计 4.3数据库设计 4.3.1数据库E-R图设计 4.3.2 数据库表结构设计 5 系统实现 5.1管理员功能介绍 5.1.1管理员登…

2分钟解决联想电脑wifi功能消失 网络适配器错误代码56

分钟解决联想电脑wifi功能消失 网络适配器错误代码56 现象 原因 电脑装了虚拟机&#xff0c;导致网络适配器冲突。我的电脑是装了vm虚拟机&#xff0c;上次更新系统后wifi图标就消失了。 解决方案 1、先卸载虚拟机 2、键盘按winr&#xff0c;弹出运行窗口&#xff0c;输入“…

LLVM PASS-PWN-前置

文章目录 参考环境搭建基础知识![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/dced705dcbb045ceb8df2237c9b0fd71.png)LLVM IR实例1. **.ll 格式&#xff08;人类可读的文本格式&#xff09;**2. **.bc 格式&#xff08;二进制格式&#xff09;**3. **内存表示** …

『功能项目』伤害数字UI显示【53】

我们打开上一篇52眩晕图标显示的项目&#xff0c; 本章要做的事情是在Boss受到伤害时显示伤害数字 首先打开Boss01预制体空间在Canvas下创建一个Text文本 设置Text文本 重命名为DamageUI 设置为隐藏 编写脚本&#xff1a;PlayerCtrl.cs 运行项目 本章做了怪物受伤血量的显示UI…

C语言 ——— 写一个宏,将一个整数的二进制位的奇数位和偶数位交换

目录 题目要求 代码实现 题目要求 写一个宏&#xff0c;将一个整数的二进制位的奇数位和偶数位交换 举例说明&#xff1a; 输入&#xff1a;10 10 的二进制为 1010 &#xff0c;奇数位和偶数位交换后得 0101 &#xff0c;也就是 5 输出&#xff1a;5 代码实现 代码演示&…

RK3568驱动指南|第十六篇 SPI-第190章 配置模式下寄存器的配置

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

电流镜与恒流源

在两个晶体管完全对称的情况下&#xff0c;电源通过R1给两个晶体管提供相同的偏置电流&#xff0c; 这样他们流过集电极和发射极的电流就相同。 电流镜原视频链接&#xff1a; 【电流镜电路】https://www.bilibili.com/video/BV1b5411k7rh?vd_source3cc3c07b09206097d0d8b0ae…

Linux基础3-基础工具3(make,makefile,gdb详解)

上篇文章&#xff1a;Linux基础3-基础工具2&#xff08;vim详解&#xff0c;gcc详解&#xff09;-CSDN博客 本章重点&#xff1a; 1.自动化构建工具make,makefile 2.linux调试工具gdb 目录 一. 自动化构建工具make,makefile 1.1 make使用 1.2 使用make注意点 a. make和文件时…

Python数据分析案例60——扩展变量后的神经网络风速预测(tsfresh)

案例背景 时间序列的预测一直是经久不衰的实际应用和学术研究的对象&#xff0c;但是绝大多数的时间序列可能就没有太多的其他的变量&#xff0c;例如一个股票的股价&#xff0c;还有一个企业的用电量&#xff0c;人的血糖浓度等等&#xff0c;空气的质量&#xff0c;温度这些…

揭秘LLM计算数字的障碍的底层原理

LLM的 Tokenizer与数字切分 大语言模型在处理语言时&#xff0c;通常依赖Tokenization技术来将文本切分为可操作的单元。早期版本的Tokenizer对数字处理不够精确&#xff0c;常常将多个连续数字合并为一个Token。比如“13579”可能被切分为“13”、“57”和“9”。在这种情况…

【Linux修行路】网络套接字编程——UDP

目录 ⛳️推荐 前言 六、Udp Server 端代码 6.1 socket——创建套接字 6.2 bind——将套接字与一个 IP 和端口号进行绑定 6.3 recvfrom——从服务器的套接字里读取数据 6.4 sendto——向指定套接字中发送数据 6.5 绑定 ip 和端口号时的注意事项 6.5.1 云服务器禁止直接…

AIGC图片相关知识和实战经验(Flux.1,ComfyUI等等)

最近看了网上的一些新闻&#xff0c;flux.1火出圈了&#xff0c;因此自己也尝试跑了一下&#xff0c;作图的质量还是蛮高的&#xff0c;在这里做个知识总结回顾。 flux.1是什么&#xff1f; 根据介绍&#xff0c;flux.1是由stable diffusion 一作&#xff0c;Stability AI的核…

数据结构----栈和队列

&#xff08;一&#xff09;栈 1.栈的概念及结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端 称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First …

【数据结构】十大经典排序算法总结与分析

文章目录 前言1. 十大经典排序算法分类2. 相关概念3. 十大经典算法总结4. 补充内容4.1 比较排序和非比较排序的区别4.2 稳定的算法就真的稳定了吗&#xff1f;4.3 稳定的意义4.4 时间复杂度的补充4.5 空间复杂度补充 结语 前言 排序算法是《数据结构与算法》中最基本的算法之一…