一个简单的图像分类项目(九)并行训练的学习:多GPU的DP(DataParallel数据并行)

        将电脑装成Ubuntu、Windows双系统,并在Ubuntu上继续学习。        

        在现代深度学习中,多主机多GPU训练已经变得非常常见,尤其是对于大规模模型和数据集。最简单和早期的并行计算比如NVIDIA的SLI,从NVIDIA 450系列驱动开始,NVIDIA官方停止了对SLI配置的支持,特别是在CUDA计算方面。现代深度学习框架通常可以通过多GPU配置来利用多块显卡,而不需要启用SLI。

       下面学习在pytorch中的并行训练。

一、DataParallel(数据并行)

        DataParallel是一种在深度学习中用于并行处理数据的技术。它可以将一个模型复制到多个设备(如多个 GPU)上,然后将数据分割并分配到这些设备上进行并行计算,以加快模型的训练速度。

优点
        加速训练过程
        在深度学习中,训练大规模的神经网络往往需要处理海量的数据。DataParallel 技术可以将数据划分成多个小批次,同时在多个计算设备(如多个 GPU)上进行处理。例如,一个具有数百万参数的图像分类模型,在处理包含数万张图像的数据集时,如果使用单个 GPU 可能需要花费数天时间来完成一个训练周期。但通过 DataParallel 将数据分配到 4 个 GPU 上并行处理,理论上可以将训练速度提高近 4 倍,大大缩短了训练时间。
        易于实现和使用
        以 PyTorch 为例,使用 DataParallel 相对简单。只需要将模型用torch.nn.DataParallel进行包装,然后像往常一样将数据输入模型进行训练即可。代码修改量较小,不需要对模型架构本身进行复杂的改动。
        硬件资源利用率高
        可以充分利用多个计算设备的计算能力。在具有多个 GPU 的服务器或计算集群中,DataParallel 能够使这些 GPU 同时工作,避免了部分硬件资源闲置的情况。这样可以更有效地利用硬件投资,特别是在处理大规模深度学习任务时,能够最大化地发挥硬件的性能。
缺点
        负载不均衡问题
        当数据划分不均匀或者模型在不同设备上的计算复杂度因数据而异时,可能会出现负载不均衡的情况。例如,在处理文本数据时,如果不同批次的文本长度差异很大,那么在处理长文本批次的设备上可能会花费更多的时间,导致各个设备的计算进度不一致,从而影响整体性能。这种负载不均衡可能会降低并行效率,使得加速比达不到理想的水平。
        通信开销较大
        在多个设备之间进行数据划分和结果合并需要一定的通信开销。设备之间需要频繁地交换数据和梯度信息,这在网络带宽有限或者设备间通信速度较慢的情况下,会成为性能瓶颈。特别是当模型参数非常多或者数据批次较大时,通信开销可能会抵消掉并行计算带来的部分性能提升。
        模型复制导致内存占用增加
        DataParallel 会在每个设备上复制一份模型,这会导致内存占用成倍增加。对于内存资源有限的设备来说,这可能会限制能够处理的模型规模或者数据批次大小。例如,在一些边缘计算设备或者小型 GPU 服务器上,可能无法承受模型的多份复制,从而无法使用 DataParallel 技术。

 项目实践

         DataParallel的实现较为简单,只需要将网络简单定义即可。将本项目中的train.py部分的代码修改为如下:

import timefrom load_imags import train_loader, train_num, test_loader, test_num
from nets import *
from torch.nn.parallel import DataParalleldef main():# 定义网络print('Please choose a network:')print('1. ResNet18')print('2. VGG')# 选择网络while True:net_choose = input('')if net_choose == '1':net = resnet18_model()net = net.to(device)net_name = 'ResNet18'print('You have chosen the ResNet18 network, start training.')breakelif net_choose == '2':net = vgg_model()# net = net.to(device)   # 不使用DataParallelnet = DataParallel(net).to(device)    # 使用DataParallelnet_name = 'VGG_simple'print('You have chosen the VGG network, start training.')breakelse:print('Please input a correct number!')# 定义损失函数和优化器loss_func = nn.CrossEntropyLoss()  # 交叉熵损失函数optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # 优化器使用Adamscheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.9)  # 学习率衰减, 每5个epoch,学习率乘以0.9# 训练模型for epoch in range(num_epoches):trained_num = 0  # 记录训练过的图片数量total_correct = 0  # 记录正确数量print('-' * 100)print('Epoch {}/{}'.format(epoch + 1, num_epoches))begin_time = time.time()  # 记录开始时间net.train()  # 训练模式for i, (images, labels) in enumerate(train_loader):images = images.to(device)  # 每batch_size个图像的数据labels = labels.to(device)  # 每batch_size个图像的标签trained_num += images.size(0)  # 记录训练过的图片数量outputs = net(images)  # 前向传播loss = loss_func(outputs, labels)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 优化器更新参数_, predicted = torch.max(outputs.data, 1)  # 预测结果correct = predicted.eq(labels).cpu().sum()  # 计算本batch_size的正确数量total_correct += correct  # 记录正确数量# 每5个epoch,学习率衰减scheduler.step()end_time = time.time()  # 记录结束时间print('Each train_epoch take time: {} s'.format(end_time - begin_time))print('This train_epoch accuracy: {:.2f}%'.format(100 * total_correct / train_num))print('-' * 60)tested_num = 0  # 记录测试过的图片数量total_correct = 0  # 记录正确数量begin_time = time.time()  # 记录开始时间net.eval()  # 测试模式for i, (images, labels) in enumerate(test_loader):images = images.to(device)  # 每batch_size个图像的数据labels = labels.to(device)  # 每batch_size个图像的标签tested_num += images.size(0)  # 记录测试过的图片数量outputs = net(images)  # 前向传播loss = loss_func(outputs, labels)  # 计算损失_, predicted = torch.max(outputs.data, 1)  # 预测结果correct = predicted.eq(labels).cpu().sum()  # 计算本batch_size的正确数量total_correct += correct  # 记录正确数量if (i + 1) % 10 == 0:  # 每10个batch_size打印一次print('tested: {}/{}'.format(tested_num, test_num))print('Loss: {:.4f}, Accuracy: {:.2f}%'.format(loss.item(), 100 * correct / images.size(0)))print('tested: {}/{}'.format(tested_num, test_num))print('-' * 30)end_time = time.time()  # 记录结束时间print('Each test_epoch take time: {} s'.format(end_time - begin_time))print('This test_epoch accuracy: {:.2f}%'.format(100 * total_correct / test_num))# 保存模型torch.save(net.state_dict(),os.path.join(model_path,time.strftime("%Y%m%d-%H-%M-", time.localtime()) +net_name + '.pkl'))  # 按结束时间和网络类型保存模型print('Finished Training')if __name__ == '__main__':main()

         只有一个地方修改:net = DataParallel(net).to(device),可以看到简单修改之后就可以实现数据并行。

        下面是数据并行修改前和修改后的运行截图对比

        修改前的GPU占用率:

d0ff8ee047194d84bdef28c63085bfb3.png

2c671b3fa99f47dc8fa5b7298bd57eb8.png

 两个GPU只有一个在工作。

训练用时:

01d1d68d15cf476db2ffeb7ff56e3313.png

 修改后的GPU占用率:

726ab2eb035346bcaf78db8e91e553bb.png

7cfc2a65ed4540b0b866534ce0839b3a.png 两个GPU均参与了训练。

a48fbdcf66f4465d9fb379ff3b3a44fb.png

        但是,训练时长比单显卡的时候变长了,原因是当前的batch_size设定较小,两个GPU之间的通信和等待同步占用时间比较多,GPU占用率很低,大部分时间都处于等待和空闲中。 将batch_size设为当前的4倍:

b95780f067014b27a0f175dc92b1c994.png

训练速度明显提升。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/18762.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

本草智选:中药实验管理的智能推荐

摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了中药实验管理系统的开发全过程。通过分析中药实验管理系统管理的不足,创建了一个计算机管理中药实验管理系统的方案。文章介绍了中药实验管理系统的系…

凸优化理论和多模态基础模型研究

文章目录 摘要Abstract1. 拉格朗日对偶问题1.1 弱对偶问题1.2 强对偶问题(P*D*)1.3 KKT条件 2. 论文阅读3. 总结 摘要 本周从拉格朗日对偶理论出发,系统学习了优化问题中凸函数、强对偶条件以及 KKT 条件的应用,并将其与机器学习…

nginx+vconsole调试网页在vivo浏览器无法显示图片问题

一、问题描述 昨天测试小伙伴提了一个特殊的bug,在安卓vivo手机浏览器上访问网页,网页的图片按钮和录播图一闪而过后便消失不见: 二、问题排查 项目采用Nuxt框架,排查的方向大致如下: 1.其它手机浏览器是否有复现&am…

草本追踪:中药实验管理的数字化转型

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

【Linux】虚拟地址空间,页表,物理内存

目录 进程地址空间,页表,物理内存 什么叫作地址空间? 如何理解地址空间的区域划分? 地址空间结构体 为什么要有地址空间? 页表 cr3寄存器 权限标记位 位置标记位 其他 每个存储单元是一个字节,一…

集群聊天服务器(3)muduo网络库

目录 基于muduo的客户端服务器编程 muduo只能装在linux中,依赖boost库 客户端并不需要高并发 基于muduo的客户端服务器编程 支持epoll线程池,muduo封装了线程池 而且还有完善的日志系统 使用muduo库代码非常固定,基本就只有chatserver的类名…

【Python刷题】最少拐弯路线问题

题目描述 洛谷P1649 一、题目理解 首先,我们来看一下这道题目的要求。题目给定了一个 NN(1≤N≤100) 的方格,方格中的每个格子有不同的状态,用 . 表示可以行走的格子,x 表示不能行走的格子,…

在windows系统里面部署 Redis

在windows中下载安装REdis 1.下载mis 地址添加链接描述 然后直接下载安装然后点击你的库 2.然后选择好之后选择好路径就行了。 然后我们点击这个cli.exe文件然后双击打开输入 在命令框里输入: 如果显示的和图片显示的一样,则证明你已经在本地部署好了…

NTP博客

使用nmtui命令修改IP: 注意: 修改之后,要激活: nmcli connection up ens160 1、软件安装 #设置当前时区 [rootlocalhost ~]# timedatectl set-timezone Asia/Shanghai 1.1.配置国内阿里yum源 [rootredhat ~]# cd /etc/yum.r…

《Large-scale Multi-modal Pre-trained Models: A Comprehensive Survey》中文校对版

文章汉化系列目录 文章目录 文章汉化系列目录摘要引言2 背景2.1 传统深度学习2.2 自然语言处理中的预训练2.3 计算机视觉中的预训练2.4 音频与语音中的预训练 3 多模态预训练3.1 任务定义与关键挑战3.2 MM-PTM的优势3.3 预训练数据3.4 预训练目标3.5 预训练网络架构3.5.1 自注意…

从源码角度分析JDK动态代理

文章目录 前言一、JDK动态代理二、动态代理的生成三、invoke的运行时调用总结 前言 本篇从源码的角度,对JDK动态代理的实现,工作原理做简要分析。 一、JDK动态代理 JDK动态代理是运行时动态代理的一种实现,相比较于CGLIB ,目标对象…

操作系统——计算机系统概述——1.5操作系统引导(开机过程)

操作系统引导: A.CPU从一个特定主存地址开始,取指令,执行ROM中的引导程序(先进行硬件自检,再开机) B.将磁盘的第一块——主引导记录读入内存,执行磁盘引导程序,扫描分区表 C.从活动分…

推荐一本python学习书:《编程不难》

推荐理由 全面:把零基础Python编程、可视化、数学、数据、机器学习,融合在一起,循循渐进。 开源:PDF、Python代码、Jupyter文档,在github直接免费下! 真实:提供大量真实场景下的数据&#xff…

数据结构与算法分析模拟试题及答案5

模拟试题(五) 一、单项选择题(每小题 2 分,共20分) (1)队列的特点是(   )。 A)先进后出 B)先进先出 C)任意位置进出 D&#xff0…

集群聊天服务器(9)一对一聊天功能

目录 一对一聊天离线消息服务器异常处理 一对一聊天 先新添一个消息码 在业务层增加该业务 没有绑定事件处理器的话消息会派发不出去 聊天其实是服务器做一个中转 现在同时登录两个账号 收到了聊天信息 再回复一下 离线消息 声明中提供接口和方法 张三对离线的李…

jedis基础入门

jedis采用key&#xff0c;value的形式保存数据&#xff0c;使用nosql sql和nosql的区别 一&#xff1a;入门案例 导入依赖 <dependencies><dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>…

QWen2.5学习

配置环境 pip install transformers 记得更新一下&#xff1a;typing_extensions pip install --upgrade typing_extensions 安装modelscope modelscope/modelscope: ModelScope: bring the notion of Model-as-a-Service to life. 下载这个仓库的代码上传到服务器解压 推…

足球青训俱乐部管理后台系统(程序+数据库+报告)

基于SpringBoot的足球青训俱乐部管理后台系统&#xff0c;系统包含两种角色&#xff1a;管理员、用户,系统分为前台和后台两大模块 编程语言&#xff1a;Java 数据库&#xff1a;MySQL 项目管理工具&#xff1a;Maven 前端技术&#xff1a;Vue 后端技术&#xff1a;SpringBoot…

MoneyPrinterTurbo - AI自动生成高清短视频

MoneyPrinterTurbo是一款基于AI大模型的开源软件&#xff0c;旨在通过一键操作帮助用户自动生成高清短视频。只需提供一个视频 主题或 **关键词** &#xff0c;就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐&#xff0c;然后合成一个高清的短视频。 ​ ​ 主要…

Cross-Inlining Binary Function Similarity Detection

注&#xff1a;在阅读该论文时顺便参考了作者团队的分享视频&#xff1a;【ICSE 2024论文预讲会-第二期-下午-哔哩哔哩】 https://b23.tv/XUVAPy3 在这个视频的末尾最后一个 一.introducion 计算下面两个函数的相似度&#xff1a; 查询函数&#xff1a;脆弱函数&#xff0c;重…