菜叶子芯酸笔记4:大模型训练、分布式训练、显存估算

大模型训练任务主要分为以下三种模型训练过程。

预训练pretrain

监督微调 supervised finetune training

奖励模型 reward model

RLHF

它们之间的顺序联系用RLHF (reinforcement learning with human feedback) 过程来阐释。

首先预训练pretrain得到一个base模型。

到微调阶段,用SFT训练得到一个Actor模型,同时拷贝得到pef_actor模型,为Actor的拷贝,pef_actor不做训练参数优化,用于防止信息遗忘,利用KL散度计算对输出logits做分布近似。

到奖励模型训练阶段,用RM训练方法得到奖励模型,同时拷贝得到critic模型,为RM的拷贝,critic不做训练参数优化,用于提升模型对数据偏好,这里会用到PPO算法。也可以直接用DPO,用偏好数据直接优化。

policy/Actor模型结构。

框图就是RLHF过程。prompt进入tokenize之后。

prompt送入old_policy/ Actor模型,得到response和old_log_prob。这里的输出是迭代反复的过程,因为输出预测的token是连续的,串行的。同时old_values也在更新。至今还是不懂这怎么更新的。

这个博客是这么说的。持有怀疑态度。

prompt 输入到Ref_policy / ref_actor模型里,得到ref_log_probs。【非必需】

prompt + response 输入到critic模型里,得到score。

计算advatanges= reward - old_values。

ref:从0开始实现LLM:7.1、Reward/PPO/DPO/KTO/SimPO详解_dpo 训练、orpo 训练和 simpo 训练-CSDN博客

计算loss。 这里a就是advantages。

更新old_policy模型。

拆解大语言模型RLHF中的PPO - 知乎

RM训练过程:RLHF之前。或者换个说法 old_policy模型的预训练

偏好数据:

在训练时会把conversations+chosen作为chosen_messages,conversations+rejected作为rejected_messages。如果存在system,则将system拼接到conversations前。将两个messages分别转换为token_ids作为输入(此时数据格式为[bs=2n, max_seq_len],n是原本设置的bs)。

而label值是将messages的system+conversations部分的token_id置为-100。

将token_ids输入模型,得到output_values,将其拆成chosen_rewards和rejected_rewards。

对于每个输入对,获取system+prompt长度,截取response+padding部分的reward_value。

数据输入:在训练时会把conversations+chosen作为chosen_messages,conversations+rejected作为rejected_messages。如果存在system,则将system拼接到conversations前。将两个messages分别转换为token_ids作为输入(此时数据格式为[bs=2n, max_seq_len],n是原本设置的bs)。

而label值是将messages的system+conversations部分的token_id置为-100。

将token_ids输入模型,得到output_values,将其拆成chosen_rewards和rejected_rewards。

对于每个输入对,获取system+prompt长度,截取response+padding部分的reward_value。

其中输入最后一个token得到的reward_value结果,作为Reward模型的最终打分。计算损失。

输入输出数据示意图

考虑到每一个chosen_data的token打分都应该比rejected_data的token打分更高。因此在LLM训练阶段,loss使用了全部的response token的loss平均。

Loss

对于chosen和rejected的数据对形式来说,自然是chosen的得分越高越好,rejected的得分越低越好,那么自然是loss=chosen_logits - rejected_logits,此时loss越大越好。

通常来说loss需要在[-1,1]或着[0,1]之间,那么可以变成loss=sigmod(chosen-rejected)。由于在数据非常大或者非常小时,sigmod函数会存在数值溢出问题,且在结果0时容易存在梯度消失问题,可以将sigmod换成logsigmod函数。又考虑到训练时需要loss越低越好,那么可以对结果取负。最终loss公式如下

不同任务的数据格式

pretrain

pretrain 的目的是对模型预训练,使得模型具备基础的知识。这里我们可以把知识理解为记忆面包,统统一股脑喂给模型即可。

1.数据样式

以 wiki_demo.txt 为例,其中每一行信息就是我们上面说到的记忆面包的知识。PT 也是一种自学习的方法,就像之前版本给出的那样,PT 的样本 source 和 target 是一样的。

实现就是设定长度,组队source和target 一样长。

SFT

SFT 有监督微调的数据以 QA 的形式提供,其中 instruction 可以作为 prompt 使用即 Q,如果 instrution 和 input 都有值,则逻辑会把二者拼接在一起作为统一个 Source。output 为 A ,在有监督微调的情况下,前面的 Prompt 会以 mask 的形式进行遮蔽,从而构造 label_ids。

RM

和前面 sft 的逻辑比较相似,sft 只有 source 和 target,这里 prompt 相当于 source,chosen_ids 相当于 positive_target 的 token ids,rejected_ids 相当于 negative_target 的 token ids,最终全部添加至 model_inputs 即可。原始样本的格式我们虽然未给出,但是大家可以自行构建,只需要在 sft 的基础上增加 bad case 即可,不过这一步需要有正负情感的数据标注。

ref:LLM - 数据处理之 Process Dataset For LLM With PT、SFT、RM_llm sft-CSDN博客

分布式计算

显存需求取决于多个因素,包括但不限于以下几点:

1. 模型参数的数量:这是显存需求的主要部分。

2. 激活(中间层输出):模型在前向传播和反向传播过程中产生的中间结果也需要存储在显存中。

3. 优化器状态:如果使用 Adam 或其他需要额外状态的优化器,这些状态也会占用显存。

4. 批量大小(Batch Size):更大的批量大小会增加显存需求。

5. 序列长度(Sequence Length):更长的输入序列会增加显存需求。

估算显存需求的方法 llamaXB 需要3XGB

1. 模型参数

每个参数通常占用 4 字节(对于 float32 类型)。如果是混合精度训练(如 float16),每个参数占用 2 字节。

* Llama 2-7B:

* 参数数量:7,000,000,000

* float32 占用:7,000,000,000 * 4 字节 = 28 GB

* float16 占用:7,000,000,000 * 2 字节 = 14 GB

* Llama 2-13B:

* 参数数量:13,000,000,000

* float32 占用:13,000,000,000 * 4 字节 = 52 GB

* float16 占用:13,000,000,000 * 2 字节 = 26 GB

* Llama 2-30B:

* 参数数量:30,000,000,000

* float32 占用:30,000,000,000 * 4 字节 = 120 GB

* float16 占用:30,000,000,000 * 2 字节 = 60 GB

2. 激活

激活的显存需求取决于模型的结构和层数。假设每层激活的显存需求为 S,则总的激活显存需求可以近似为 S * 层数 * 批量大小 * 序列长度。

3. 优化器状态

对于 Adam 优化器,每个参数需要额外的两个状态变量(梯度的平方和梯度的一阶矩),因此每个参数需要额外的 8 字节(float32)或 4 字节(float16)。

* Llama 2-7B:

* float32 占用:7,000,000,000 * 8 字节 = 56 GB

* float16 占用:7,000,000,000 * 4 字节 = 28 GB

* Llama 2-13B:

* float32 占用:13,000,000,000 * 8 字节 = 104 GB

* float16 占用:13,000,000,000 * 4 字节 = 52 GB

* Llama 2-30B:

* float32 占用:30,000,000,000 * 8 字节 = 240 GB

* float16 占用:30,000,000,000 * 4 字节 = 120 GB

4. 总显存需求

总显存需求是上述各项的总和。假设批量大小为 1,序列长度为 2048,且不考虑其他额外开销,我们可以得到一个大致的估计。

* Llama 2-7B:

* 参数:14 GB (float16)

* 激活:假设每层激活约 1 MB,则总激活约为 1 MB * 32 层 * 1 * 2048 ≈ 64 MB

* 优化器状态:28 GB (float16)

* 总显存:14 GB + 64 MB + 28 GB ≈ 42 GB

* Llama 2-13B:

* 参数:26 GB (float16)

* 激活:假设每层激活约 1 MB,则总激活约为 1 MB * 32 层 * 1 * 2048 ≈ 64 MB

* 优化器状态:52 GB (float16)

* 总显存:26 GB + 64 MB + 52 GB ≈ 78 GB

* Llama 2-30B:

* 参数:60 GB (float16)

* 激活:假设每层激活约 1 MB,则总激活约为 1 MB * 32 层 * 1 * 2048 ≈ 64 MB

* 优化器状态:120 GB (float16)

* 总显存:60 GB + 64 MB + 120 GB ≈ 180 GB

并行策略

数据并行空间复杂度

ADAM 优化器 + 混合精度训练情况下模型状态的显存占用

综上所述,训练过程中的显存占用可分为两大部分:

模型状态:记模型本身参数量为 Φ ,在 Adam + 混合精度训练的情况下,模型状态包括 fp16 的模型参数 2 Φ 和参数梯度 2 Φ 和 fp32 的模型参数备份 4 Φ ,momentum 4 Φ 和 variance 4 Φ ,即总共 2 Φ + 2 Φ + 4 Φ + 4 Φ + 4 Φ = 16Φ 。(注意 fp16 占两个字节,fp32 占四个字节)

剩余状态:即训练中的激活值、临时缓冲区和显存碎片等。

以 GPT-2 为例,GPT-2 模型含有 1.5B 个参数,如果用 fp16 格式,模型本身只占 3GB 显存,但是实际训练过程中的模型状态需要耗费 24GB【1.5*16】!可以看到。模型状态是成倍于模型本身的大小,是显存消耗的大头。并且,对于剩余状态中的激活值等,已经有 activation checkpointing 等以时间换空间的优化方式,可以有效减小这部分显存消耗。因此,优化模型状态的显存占用是重点。

ZeRO 由 ZeRO-DP 和 ZeRO-R 组成,分别是对模型状态和剩余状态的显存优化。

ref : 分布式训练数据并行极致优化:ZeRO_zero分布式训练-CSDN博客

#TODO:

#smooth quant

#deepspeed megatron

#mixtral MOE

#vllm tensorrt-llm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/16202.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Python爬虫----python爬虫基础

一、python爬虫基础-爬虫简介 1、现实生活中实际爬虫有哪些? 2、什么是网络爬虫? 3、什么是通用爬虫和聚焦爬虫? 4、为什么要用python写爬虫程序 5、环境和工具 二、python爬虫基础-http协议和chrome抓包工具 1、什么是http和https协议…

什么是低温温度传感器

低温学是物理学的一个分支,处理极低温度的产生和影响。已经基于各种与温度相关的特性开发了低温温度传感器。常见的市售传感器包括电阻器,电容器,热电偶和诸如二极管或晶体管的半导体结器件。 主要标准级传感器对热和机械冲击非常敏感&#…

【SpringBoot】23 文件预览(kkFileView)

Gitee仓库 https://gitee.com/Lin_DH/system 介绍 文件预览功能是指在不打开或编辑文件的情况下,通过某种方式查看文件的内容、格式或者部分内容的功能。该功能通常用于文件管理系统、办公工具、在线教育平台、企业协作平台、电子邮件客户端等领域,能…

PC提取微信语音

首先,多选需要转存的语音信息——点击下方正方体图标收藏——打开收藏界面,找到语音文件打开——点击界面上放3个小点,选择转存为笔记。 然后,打开电脑端微信,点击左侧收藏图标,找到保存的语音文件打开&am…

STM32 ADC --- 单通道采样

STM32 ADC — 单通道采样 文章目录 STM32 ADC --- 单通道采样cubeMX配置代码修改:应用 使用cubeMX生成HAL工程 需求:有多个通道需要进行ADC采样,实现每次采样只采样一个通道,且可以随时采样不同通道的功能。 cubeMX配置 这里我们…

力扣 LeetCode 150. 逆波兰表达式求值(Day5:栈与队列)

解题思路: 逆波兰表达式就是从二叉树的后序遍历得来的(左右根),因此计算机直接按顺序取出表达式中元素进行运算即可,无需考虑括号的运算顺序,加快运算速度 对于(12)x(3…

交通路口智能监测平台实现

🏡作者主页:点击! 🤖编程探索专栏:点击! ⏰️创作时间:2024年11月15日8点12分 神秘男子影, 秘而不宣藏。 泣意深不见, 男子自持重, 子夜独自沉。 论文链接 点击开启你的论文编程之旅h…

Redis 持久化机制 RDB 和 AOF 区别

Redis 是一个开源的内存数据结构存储系统,广泛应用于缓存、会话存储、实时分析等场景。虽然 Redis 本质上是内存数据库,但它支持持久化机制,将数据保存在磁盘中以防止数据丢失。在 Redis 中,主要有两种持久化机制:RDB(…

uniapp动态获取练习题的内容选项和最终选择的结果

里面的练习题题目和选项都是动态获取的&#xff0c;提交的时候结果是多个单选题最终选择的值&#xff0c;重点是给单选组标签上加上change事件&#xff0c;多选通用&#xff0c;change事件内加一个回调&#xff0c;代码示例如下&#xff1a; <template> <view class&…

联想 ThinkPad的高级键盘功能

前言&#xff1a; 用好键盘是程序员最需要花时间了解的。 联想ThinkPAD的高级键盘功能和windows的键盘功能是不一样的。学习一下&#xff0c;给自己的工作&#xff0c;编程带来很大的的提高。花时间是有意义的。 调出设置&#xff1a; 1 先是键盘管理&#xff1a; 这里&#…

红黑树

目录 红黑树 红黑树的概念 红黑树的性质 红黑树节点的定义 插入的代码实现 情况一 情况二 uncle不存在 uncle存在且为黑单旋 情况三 uncle存在且为黑的双旋情况 情况二和情况三的总代码 以上是父亲在爷爷左边的情况,右边的情况也类似 左旋代码 右旋代码 红黑树…

MySQL进阶-索引的组合索引

练习题目 题目链接难度SQL进阶-索引的组合索引★★★☆☆ SQL思路 SQL进阶-索引的组合索引 初始化数据 drop table if exists user_profile; CREATE TABLE user_profile ( id int NOT NULL, device_id int NOT NULL, gender varchar(14) NOT NULL, age int , university va…

适用比亚迪汽车生产线的RFID高频读写器

随着人工智能和物联网技术的发展&#xff0c;汽车产线正朝着高度自动化和智能化的方向发展&#xff0c;许多汽车制造商选择将RFID技术应用在其生产线上&#xff0c;以提高生产效率、降低劳动强度。例如比亚迪等汽车生产线上已经广泛应用RFID技术。 健永科技利用自身的研发能力…

用Python实现中国象棋(详细教程 | 附代码)

创建一个完整的中国象棋游戏是一个复杂的项目&#xff0c;涉及到游戏规则、用户界面、AI算法等多个方面。在这里&#xff0c;我将提供一个更完整的Python代码示例&#xff0c;包括基本的棋盘、棋子移动规则和简单的用户交互。但请注意&#xff0c;这仍然是一个简化的版本&#…

力扣-Mysql-3308- 寻找表现最佳的司机(中等)

一、题目来源 3308. 寻找表现最佳的司机 - 力扣&#xff08;LeetCode&#xff09; 二、数据表结构 表&#xff1a;Drivers ----------------------- | Column Name | Type | ----------------------- | driver_id | int | | name | varchar | | age …

LeetCode 209.长度最小的子数组

209.长度最小的子数组 思路&#x1f9d0;&#xff1a; 该题可以用滑动窗口进行解答&#xff0c;滑动窗口的意思是&#xff0c;我们判断一段区间的情况&#xff0c;再根据不同情况进行区间的更新。 这里要求满足总和大于等于target的子数组&#xff0c;那么我们可以用两个指针当…

国网山东电力生产检修建设基地绿色低碳智慧用能项目获创新创意劳动竞赛一等奖

原标题&#xff1a;深化开展“供电能效服务”&#xff0c;全力推动全社会能效提升&#xff0c;国网山东电力生产检修建设基地绿色低碳智慧用能项目获得全省智慧综合能源服务项目创新创意劳动竞赛一等奖 11月14日,由山东省发展和改革委员会、山东省总工会、山东省能源局主办,山…

AIHub: 模型和数据集的私有云存储库

AIStor 的最新功能之一是广受欢迎的开源项目 Hugging Face 的私有云版本。这篇文章详细介绍了 AIStor 的 AIHub 如何有效地创建一个完全由企业控制的 API 兼容的私有云版本的 Hugging Face。在我们开始之前&#xff0c;介绍 Hugging Face 是有意义的。Hugging Face 是面向 AI 工…

【SAP FICO】财务三大报表_2-进阶(现金流量表-数据表结构、取数逻辑)

系列文章目录 文章目录 系列文章目录前言一、现金流量表二、现金流量表的数据表结构1、核心数据表2、内部数据结构 三、现金流量表的取数逻辑1、获取用户输入2、获取数据3、处理数据 总结 前言 承接上篇财务三大报表_2-进阶&#xff08;利润表-数据表结构、取数逻辑&#xff0…

【人工智能】深入解析!三种实现ChatGPT打字机效果的最佳方案

在当今AI快速发展的时代&#xff0c;ChatGPT 凭借其强大的自然语言处理能力&#xff0c;已经成为众多开发者和企业的首选工具。然而&#xff0c;如何在前端页面中实现类似于ChatGPT的打字机效果&#xff0c;以提升用户交互体验&#xff0c;成为了一个广受关注的话题。今天&…