LLM大模型训练/推理的显卡内存需求计算

无论你是从头开始训练 LLM、对其进行微调还是部署现有模型,选择合适的 GPU 对成本和效率都至关重要。在这篇博客中,我们将详细介绍使用单个和多个 GPU 以及不同的优化器和批处理大小进行 LLM 训练和推理时 GPU 要求的所有信息。

计算机处理器由多个决定性电路组成,每个电路都可以处于关闭或打开状态。就内存而言,这两种状态由 0 或 1 或位表示。一组八位称为一个字节。1 个字节可以表示零(00000000)和 255(11111111)之间的数字,或 28(等于 256 个不同位置)。通常,在 FP-32(包括符号、指数和尾数)数据类型上训练的神经网络最多占用 4 个字节的内存。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

模型参数常用的数据类型如下:

  • float(32 位浮点):每个参数 4 个字节
  • half/BF16(16 位浮点):每个参数 2 个字节
  • int8(8 位整数):每个参数 1 个字节
  • int4(4​​ 位整数):每个参数 0.5 个字节

1、什么会消耗 GPU 内存?

在模型训练期间,大部分内存被四个东西消耗

11 模型参数

模型参数是神经网络的可学习组件。它们定义网络的结构和行为,并在训练期间更新以最小化损失函数。通常,我们有权重和偏差参数。

正如我们已经知道的那样,存储一个数字需要 4 个字节。假设我们的模型中有 P 个参数。

  • 参数内存(M)= 参数数量(P)x 精度大小(4 字节)
  • M = Px4
  • 16 位 M = P x 精度大小(2 字节)也类似

我们可以添加一个缩放因子并制定一个标准公式,如下所示:

这里 1.2 表示在 GPU 内存中加载额外内容的 20% 开销,Q 是加载模型应使用的位数。即 16 位、8 位或 4 位。

16 位 Llama 70B 需要 GPU 内存:

这是推理 Llama 70b 模型所需的总体最低 GPU。

1.2 激活

当输入数据通过网络时,激活是每层神经元的中间输出。在前向传递过程中,每层处理输入数据,应用权重、偏差和激活函数(如 ReLU、sigmoid 等)来产生激活。然后,这些激活将作为下一层输入。

需要存储每个层的激活,因为它们在反向传播期间用于计算梯度。

激活内存 = 激活数量 x 批次大小 x 精度大小

注意:“每个参数的激活”取决于模型架构、层数和序列长度。对于大型模型,激活通常需要与参数相当或超过参数的内存。将序列长度加倍也可能使激活内存加倍。

近似值:没有固定的公式来计算激活的 GPU 内存。对于较大的模型,激活所需的内存可能大致与参数的内存相似或略大。

1.3 梯度

梯度是损失函数关于模型参数的偏导数。它们表示应调整每个参数多少以最小化损失函数。

在反向传播期间,损失通过网络向后传播,并计算每个参数(权重和偏差)的梯度。优化器使用这些梯度来更新参数,从而减少整体损失。

存储梯度所需的内存等于参数本身所需的内存。由于每个参数都有相应的梯度,因此它们的内存要求相同。

梯度内存 = 参数内存

1.4 优化器状态

优化器状态是某些优化算法(如 Adam、RMSprop)维护的附加变量,用于提高训练效率。这些状态有助于根据过去的梯度更新模型参数。

不同的优化器维护不同类型的状态。例如:

  • SGD(随机梯度下降):没有附加状态;仅使用梯度来更新参数。
  • Adam:为每个参数维护两个状态:一阶矩(梯度平均值)和二阶矩(梯度平方平均值)。这有助于动态调整每个参数的学习率。对于具有 100 万个参数的模型,Adam 需要为每个参数维护 2 个附加值(一阶矩和二阶矩),从而产生 200 万个附加状态。

优化器状态的内存 = 参数数量 x 精度大小 x 优化器乘数

2、单GPU内存需求计算

我们举个例子

我们想在单个 GPU 上以混合精度(2 字节)训练 100 亿模型。

  • 参数内存=参数数量 x 2 字节 (FP16)
  • 参数内存=10B x 2 字节 = 20 GB
  • 激活内存=每个参数的激活 x 批次大小 x 精度大小

我们可以计算每层激活内存,而不是计算激活的总内存,这是一种高效的方法,需要的内存更少,因为它可以在下一层使用。

  • 每层神经元的近似数量 = sqrt(10B) ≈ 每层 100k 个神经元
  • 一层的激活内存 ≈ 32 x 100k x 2 字节 ≈ 每层 6.4 MB

对于大型模型中的层(假设有数百层),激活内存最多可达数十 GB。

因此,正如我们之前讨论过的,对于 32 的批次大小,大约需要 20-40 GB 的内存。现在,如果我们将批次大小加倍,这个范围可以加倍。

  • 梯度内存 = 参数内存
  • 梯度内存 = 20 GB
  • 优化器状态内存 = 参数数量 x 4 字节 (FP32) x 2 (用于 Adam)
  • 优化器状态内存 = 10B x 4 字节 x 2 = 80 GB

总内存估计:

  • 参数内存:20 GB
  • 激活内存:≈20–40 GB(取决于批次大小)
  • 梯度内存:20 GB
  • 优化器状态内存:80 GB
  • 总内存 = 20 + 20 到 40 + 20 + 80 = 140 到 160 GB

3、多个 GPU 的内存计算

要计算在 n 个 GPU 上训练时每个 GPU 的内存需求,我们需要考虑如何使用数据并行和模型并行等并行技术在 GPU 上分配内存。

关键假设:

  • 模型并行:模型的参数在 GPU 之间分配,因此每个 GPU 仅存储总模型参数的一小部分。梯度和优化器状态也同样被划分。
  • 数据并行:每个 GPU 都会获得整个模型参数的副本,但数据批次会在 GPU 之间分配。激活是针对每个 GPU 的小批次单独计算的。

如果我们使用模型并行性,那么所有模型参数、梯度和优化器统计数据都是分布式的。

但是,每个 GPU 仍然需要存储其批次部分的激活。激活的内存不会随着 GPU 数量的增加而减少,因为每个 GPU 都独立处理自己的数据。

因此,对于所有 GPU 来说,激活所需的内存仍然相同

因此,在 n 个 GPU 上以混合精度(2 字节)训练 100 亿模型所需的总内存为:

如果我们想使用 2 个 GPU 训练 LLM,我们需要大约 8o 到 100 GB 的内存。


原文链接:LLM显卡内存需求计算 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1543004.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C/C++逆向:switch语句逆向分析

在逆向分析中,switch语句会被编译器转化为不同的底层实现方式,这取决于编译器优化和具体的场景。常见的实现方式包括以下几种: ①顺序判断(if-else链): 编译器将switch语句转化为一系列的if-else语句。这…

管道物体计数系统源码分享

管道物体计数检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

信创背景下中职计算机组装与维护课程教学解决方案

在当前的国际形势下,确保信息化系统的安全性和可靠性显得尤为重要。为了提高信息技术的安全性和可靠性,国家鼓励并支持使用国产的信息技术、工具和资源来替代现有的技术体系。这一过程被称为“安全可信的创新替代”,它已经成为国家安全战略的…

VMware ESXi 8.0U3b macOS Unlocker OEM BIOS 2.7 标准版和厂商定制版

VMware ESXi 8.0U3b macOS Unlocker & OEM BIOS 2.7 标准版和厂商定制版 ESXi 8.0U3 标准版,Dell (戴尔)、HPE (慧与)、Lenovo (联想)、Inspur (浪潮)、Cisco (思科)、Hitachi (日立)、Fujitsu (富士通)、NEC (日电) 定制版、Huawei (华为) OEM 定制版 请访问…

OpenResty安装及使用

🍓 简介:java系列技术分享(👉持续更新中…🔥) 🍓 初衷:一起学习、一起进步、坚持不懈 🍓 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正🙏 🍓 希望这篇文章对你有所帮助,欢…

构建高可用和高防御力的云服务架构第四部分:REDIS(4/5)

本文的目的是深入探讨Redis在构建高可用和高防御力云服务架构中的应用。我们将分析Redis的工作原理、核心特性以及如何通过Redis优化云服务架构的性能和安全性。此外,我们还将提供实际案例和最佳实践,帮助读者更好地理解和应用Redis,以构建更…

中小企业体系技术抽象沉淀-异地灾备篇

IT团队内部使用工具 系列文章:https://blog.csdn.net/caicongyang/article/details/136857045 DDL DML管控 https://github.com/hhyo/Archery/ flyway 文档编写 wiki 技术对外输出文档推荐gitbook 同城双活数据同步方案 总览: vivo 系列文章&#x…

普通程序员如何快速入门AIGC

文章目录 第1阶段:基础知识打牢 (1-2周)第2阶段:深度学习理论与实践 (2-4周)第3阶段:AIGC 生成技术入门 (3-5周)第4阶段:进阶学习和项目实战 (5-8周)第5阶段:保持学习和更新 (持续进行) 要快速入门 AIGC(AI…

SPI驱动学习六(SPI_Master驱动程序)

目录 前言一、SPI_Master驱动程序框架1. SPI传输概述1.1 数据组织方式1.2 SPI控制器数据结构 2. SPI传输函数的两种方法2.1 老方法2.2 新方法 二、如何编写SPI_Master驱动程序1. 编写设备树2. 编写驱动程序 三、SPI_Master驱动程序简单示例demo1. 使用老方法编写的SPI Master驱…

Webrtc开发实战系列 - win10+vs2022下编译最新webrtc代码

1. 准备起步 操作系统:windows 10 安装 vs2019/vs2022 安装 win10 sdk 19041 一定勾选 Debugging Tools for Windows 科学上网准备代理工具 磁盘剩余空间至少 30G 推荐用一台干净的机器或者虚拟机来编译WebRTC,安装过python的会出现一些非常棘手…

昂首资本:欧美货币对的交易智慧

在外汇市场的海洋中,昂首资本的投资者们深知,把握欧美货币对的交易时段是获取收益的关键。欧美货币对,即欧元对美元,因其在欧洲和美国市场的活跃交易时段而备受瞩目。这两个时段不仅交易量巨大,而且价格波动剧烈&#…

【隐私计算篇】利用多方安全计算MPC实现VGG16人脸识别隐私推理

1. 背景介绍 本文主要介绍一种利用多方安全计算MPC技术,实现VGG16的人脸识别模型,侧重于模型推理阶段,目前已经公开专利,因此以下内容的分享都是基于公开材料。该分享涉及到最小化多方安全计算(MPC)以及明密文混合计算的思想&…

JAVA开源项目 甘肃非物质文化网站 计算机毕业设计

本文项目编号 T 043 ,文末自助获取源码 \color{red}{T043,文末自助获取源码} T043,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

python画图|把X轴标签移动到图像顶端

在前述学习过程中,我们一直使用的是默认的轴坐标,X轴往往置于图像的下端。 有时候,也会有将X轴标签放置在图形顶端的需求,今天就一起学习一下。 【1】官网教程 首先打开官网,可以通过下述链接一步直达: …

软考高级:系统安全 -区块链特点:去中心化、开放性、自治性、安全性、匿名性

讲解 生活化例子 想象一下,你和朋友们玩一个共享账本的游戏。每个人都可以在账本上记账,没人可以单独改动账本,大家都可以随时查看账本内容,也不用再信任某个单独的人来管理账本。这就类似于区块链的工作原理。 概念讲解 去中…

基于c++实现的简易shell

代码逻辑 核心思想 解析命令行,拆解命令及其选项创建子进程,在子进程中执行命令如果是前台执行命令,则父进程就阻塞等待子进程中命令执行结束后回收子进程的资源如果是后台执行命令,则父进程不进行阻塞等待,可继续向下…

【机器学习】---神经架构搜索(NAS)

这里写目录标题 引言1. 什么是神经架构搜索(NAS)1.1 为什么需要NAS? 2. NAS的三大组件2.1 搜索空间搜索空间设计的考虑因素: 2.2 搜索策略2.3 性能估计 3. NAS的主要方法3.1 基于强化学习的NAS3.2 基于进化算法的NAS3.3 基于梯度的…

【数据结构】图的遍历

快乐的流畅:个人主页 个人专栏:《C游记》《进击的C》《Linux迷航》 远方有一堆篝火,在为久候之人燃烧! 文章目录 引言一、深度优先遍历1.1 定义1.2 实现 二、广度优先遍历2.1 定义2.2 实现 三、DFS与BFS的对比 引言 前置知识&…

linux用户管理运行级别找回root密码

目录 1.用户的添加 1.1用户添加的基本指令 1.2不指定家目录的名称 1.3指定家目录的名称 2.密码的修改 3.删除目录 3.1删除的两个情况 3.2删除的流程 4.查询用户的信息 5.用户的切换 6.用户组 6.1用户组的概念 6.2创建用户到指定的组 6.3修改用户到其他的组 6.4用…

SpringCloud Alibaba之Sentinel实现熔断与限流

(学习笔记) QPS(Query Per Second):即每秒查询率,是对⼀个特定的查询服务器在规定时间内所处理流量多少的衡量标准。QPS req/sec 请求数/秒,即每秒的响应请求数,也即是最⼤吞吐能⼒…