Sample Packing:长序列 LLM 训练的 Attention 问题及优化

一、背景

之前看过部分 Megatron-LM 的源码,也详细分析过对应的 Dataset 和 DataLoader,想当然的认为在 LLM 预训练时会使用 Document Level 的 Mask,也就是常说的 Sample Packing 技术。最近我们在做长序列训练相关工作时发现并非如此,并且出现了一些很奇怪的性能问题,因此重新看了相关工作,并进行了部分实验。

Sample Packing 中有很多可以讨论的技术点,比如 Attention 的实现和优化,Sample 的组合及负载均衡问题(有点类似调度问题)以及不同方案对效果的影响等。我们这里只是先简单介绍一下相关问题和实验,后续会进一步探索更多工作,比如 Document Level 的 Mask 到底对预训练效果影响有多大,对 Attention 进行优化还能带来多少提升,如何设计一个比较好的 Packing 策略等?

二、Dataset + Dataloader

简单说来,预训练通常包含很多不同的数据集,每个数据集又包含许多 Document。为了提升训练效率,在实际训练的时候一个 Sample(Sequence)里面可能会包含多个不同的 Document(Sample Packing)。比如 8K 的预训练 Sequence Length,则一个 Sample 可以包含 8 个 1K 的 Document。

如下图所示,简单展示了 Megatron-LM 中如何 Packing 多个 Document,实际上就是一个多级的索引。需要说明的是,这里其实会引入很多随机读操作,会极大影响读的性能。不过一般 LLM 计算代价都很高,这里也往往不会导致瓶颈。

三、Attention Mask

对于单个 Document 而言,Decoder Only 的 GPT 模型具有 Causal 特性,也就是每个 Token 不能看到之后的 Token,因此在实际训练中需要添加 Attention Mask。如下图所示,这种情况下 Attention Mask 是一个标准的下三角矩阵(Causal Mask),也就是绿色部分为 1,其他部分为 0:

如果一个 Sample 里包含多个样本,则 Attention Mask 矩阵需要变成如下图所示的块对角矩阵形式(Block Diagonal Mask)。比如 Sequence Length 为 16,4 个 Document 的长度分别为 3,4,5,4,则对应 Attention Mask 矩阵如下图所示,对角线上的 4 个矩阵(红框)都是标准的下三角矩阵。按照这种方式可以保证和 4 个 Document 单独作为 Sample 训练是等价的:

四、Reset Attention Mask

4.1 是否需要

那么在实际使用中是否需要严格按照 Block Diagonal Mask 的方式使用呢?答案是否定的,比如 Megatron-LM 可以通过 reset_attention_mask 来控制是使用 Block Diagonal Mask 还是标准的 Causal Mask,默认值为 False。很多模型在预训练时也会采用默认配置,即使用 Causal Mask。

在浪潮的 Yuan-1.0 报告(“源1.0”大模型技术白皮书)中有提到,为了避免不同 Document 之间的相互干扰而将 reset_attention_mask 设置为 True,也就是 Block Diagonal Mask:

在 Meta 的 LLaMA 3.1 技术报告([2407.21783] The Llama 3 Herd of Models)中也提到,在 LLaMA 3.1 模型的预训练中会打开这个配置。不过作者也做了说明,对于 8K Sequence Length 的预训练而言,对模型最终的效果影响不大,对长序列的 Continuous PreTraining 影响比较大:

在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者提出了“世界模型”,为了提升超长序列的训练效率,作者采用了 Sample Packing 的策略,并且做了相关消融实验。如下图 Table 10 所示,采用 Naive Packing(不对 Attention Mask 特殊处理)相比使用了 Block Diagonal Mask 的 LWM 的性能会差很多:

PS:当然,目前还没有更多有关预训练中是否 reset_attention_mask 的消融实验,我们后续会进行相关测试。此外,如果采用绝对位置编码,Position-id 也需要相应的调整,在 Megatron-LM 中对应 reset_position_id 选项。

4.2 性能问题

如下图为 Megatron-LM/megatron/core/datasets/gpt_dataset.py 中 reset_attention_mask 的实现方式,首先会将 attention_mask 初始化为标准的 Causal Mask 形式,然后从第二个 Document 开始,将之前的 mask 置为 0:

具体来说如下图所示,初始是一个标准的 Causal Mask 矩阵,然后会将 4x3、5x(3+4) 和 4x(3+4+5) 的区域依次置为 0,之后会变成 Block Diagonal Mask:

实际上我们已经知道这里是标准的 Block Diagonal Mask,可以使用 torch.block_diag() 快速创建。实测当序列比较长时(比如 32K),两种方式速度可能会差几十倍,导致 reset_attention_mask 可能成为训练瓶颈:

除此之外,当序列非常长时,Attention Mask 也会占据很大的存储空间,为了计算效率,往往会使用整型而不是 Bool 类型。假设以 int8 存储,32K 序列长度对应的 Mask 大小为 32K * 32K = 1GB,128K 时更是高达 16GB。为了避免显存浪费,其实不必将其拼成大的 Block Diagonal Mask,而保留几个小的 Causal Mask 即可。

五、Attention 优化

5.1 FlashAttention

当前 LLM 预训练基本都会使用 FlashAttention,其对 Casual Mask 的方式进行了优化,如下图所示,假设 16x16 的 Attention Mask,在计算时按照 4x4 分块,则可以将其分为 3 种情况:

  • 有些块对应的 Mask 都是 0(红框右上部分,比如蓝框),无需再计算。

  • 有些块中部分 Mask 为 0,部分 Mask 为 1(红框),需要相应特殊处理。

  • 有些块对应的 Mask 都是 1(红框左下部分,比如黄框),全部计算即可。

对于上述 Block Diagonal Mask,依然可以使用 Causal Mask 的方式计算,不过会导致大量的无效计算。幸运的是,FlashAttention V2 支持可变序列长度(Varlen)的 Batching Attention 计算,可以避免 Padding 导致的无效计算。因此也就可以借用这种机制来对 Block Diagonal Mask 进行解构,重新分解为多个 Causal Mask 分别计算,可以避免很多无效计算。如下图所示,可以将其看成 4 个独立的 Attention 计算,具体可以参考 FlashAttention Github 上的相关讨论:How to implement example packing with flash_attn v2? · Issue #654 · Dao-AILab/flash-attention · GitHub 和 Will attention_mask be extended to 3D? (concatenate short samples for efficient training) · Issue #432 · Dao-AILab/flash-attention · GitHub。

在 GLM-4([2406.12793] ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools)中也应用了 Sample Packing 方案,并且同样使用了 Block Diagonal Mask 机制来区分不同的 Document。并且作者也是基于 FlashAttention 的 Varlen 功能来实现。

5.2 Pytorch FlashAttention

Pytorch 的 scaled_dot_product_attention 提供了高效的 Attention 实现,也集成了 FlashAttention2 的实现,然而其不支持上述的可变序列长度的功能,导致针对 Block Diagonal Mask 场景时会存在大量的重复计算。

此外,我们在之前的文章中也多次提到,当序列比较短时,Attention 部分计算的占比并不是特别大,因此其中的冗余计算可能对整体训练速度影响不大;但当序列比较长时,Attention 部分计算的占比会越来越大,冗余计算可能会对训练速度有比较大的影响,也就需要对其进行优化。

5.3 FlexAttention

Pytorch 在 2.5.0 版本引入了 FlexAttention(FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention),可以很容易支持各种 Attention Mask 变种,比如标准 Causal Mask、Sliding Window + Causal、Prefix Mask 以及 Document Mask(Block Diagonal Mask)等,相比 FlashAttention 也更加的灵活。

我们基于 FlexAttention 进行了相关测试,以验证使用 Block Diagonal Mask 的性能优势。首先以两个 16K Document 拼接为一个 32K Sample 为例,Attention Mask 大概是如下图所示方式,对应的稀疏度为 74.80%(整个 Mask 中 0 的占比):

如下图所示我们在 H100 GPU 上进行的 Attention 相关性能测试。可以看出, Pytorch 的 Causal + FlashAttention2 方式确实可以达到非常高的 TFLOPS,明显高于 FlexAttention。然而,因为 FlexAttention 中避免了很多无效计算,实际的 Forward 和 Backward 时间反而更短:

当然,也并不意外着 FlexAttention 总是更优的,还和 Sample 中 Document 长度有关。如下图所示为相应测试结果,32K 表示 Sample 中只有一个 Document,2K + 30K 表示 Sample 中有 2 个 Document,一个长度 2K,一个长度 30K。从下图基本上可以得出这样一个结论:当 Sample 中最长的 Document 的长度 <= Sequence Length/2 时,使用 FlexAttention 可能会带来更大的收益:

那么为什么“最长的 Document 的长度 <= Sequence Length/2”时会有收益呢?其实可以简单从稀疏度的角度考虑:假设 a1 + a2 + a3 + … + an = S,并且 0 < a1 <= a2 <= a3 <= … <= an <= S/2,那么可以用数学归纳法得出 (a1)^2 + (a2)^2 + (a3)^2 + … + (an)^2 <= S^2/2。也就是说,最长的 Document 的长度 <= Sequence Length/2 时,稀疏度会 >= 75%(还要考虑 Causal 特性),相应的 FlashAttention 中至少有一半的冗余计算。

因此,我们也需要充分考虑在长文本训练过程中短文本的占比,极端情况下训练数据全部是超长文本,每个 Sample 中都只有一个 Document,Block Diagonal Mask 会退化为 Causal Mask。不过有些时候为了避免模型出现灾难性遗忘,也会混合一些短文本数据,或者高质量的预训练数据,不可避免的会出现冗余计算的问题。

5.3 Sequence Parallel

针对长序列场景通常会采用 RingAttention 和 USP 等,然而不管是 RingAttention 还是其 LoadBalance 版本(如下图 Figure 3 所示)等都没有太多讨论 Sample Packing 的情况。对于 Block Diagonal Mask 场景,其相应的优化,LoadBalance 策略也可能需要对应调整:

在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者(也是 RingAttention 的作者)声称针对 Block Diagonal Mask 场景对 RingAttention 进行相关优化,但并没有对比优化前后训练速度的提升。

PS:整体来说,在各种序列并行技术中更好的兼容 Block Diagonal Mask 场景又会有更多的挑战,我们留作后续介绍。

如何学习大模型

现在社会上大模型越来越普及了,已经有很多人都想往这里面扎,但是却找不到适合的方法去学习。

作为一名资深码农,初入大模型时也吃了很多亏,踩了无数坑。现在我想把我的经验和知识分享给你们,帮助你们学习AI大模型,能够解决你们学习中的困难。

我已将重要的AI大模型资料包括市面上AI大模型各大白皮书、AGI大模型系统学习路线、AI大模型视频教程、实战学习,等录播视频免费分享出来,需要的小伙伴可以扫取。

一、AGI大模型系统学习路线

很多人学习大模型的时候没有方向,东学一点西学一点,像只无头苍蝇乱撞,我下面分享的这个学习路线希望能够帮助到你们学习AI大模型。

在这里插入图片描述

二、AI大模型视频教程

在这里插入图片描述

三、AI大模型各大学习书籍

在这里插入图片描述

四、AI大模型各大场景实战案例

在这里插入图片描述

五、结束语

学习AI大模型是当前科技发展的趋势,它不仅能够为我们提供更多的机会和挑战,还能够让我们更好地理解和应用人工智能技术。通过学习AI大模型,我们可以深入了解深度学习、神经网络等核心概念,并将其应用于自然语言处理、计算机视觉、语音识别等领域。同时,掌握AI大模型还能够为我们的职业发展增添竞争力,成为未来技术领域的领导者。

再者,学习AI大模型也能为我们自己创造更多的价值,提供更多的岗位以及副业创收,让自己的生活更上一层楼。

因此,学习AI大模型是一项有前景且值得投入的时间和精力的重要选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147843.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

灰狼算法求解函数,MATLAB代码

目录 程序说明 概述 主要功能 关键函数 结论 程序说明 概述 该程序实现了灰狼优化算法&#xff08;GWO&#xff09;&#xff0c;用于求解优化问题。该算法模拟灰狼的捕猎行为&#xff0c;通过种群搜索找到最优解。程序中定义了种群数量、问题维度、变量上下界和适应度函…

全行业商家0退货0退款一键卖全球,淘天助力卖家拓展海外生意!

今年7月中旬&#xff0c;淘宝推出了“大服饰全球包邮计划”&#xff0c;在服饰行业先行先试&#xff0c;带领商家拓展海外市场。计划推出以来&#xff0c;吸引了数十万服饰商家报名参与&#xff0c;包括天猫商家蕉下、淘宝商家JOC、美洋等。有服饰商家怀着试一试的心态报了名&a…

碳课堂|CBAM的制度及核心内容

引言 全球变暖和气候变化是21世纪面临的最严峻挑战之一。为应对这一全球性问题&#xff0c;各国纷纷采取措施&#xff0c;减少温室气体排放&#xff0c;并推动可持续发展。其中&#xff0c;欧盟提出的碳边界调整机制&#xff08;CBAM, Carbon Border Adjustment Mechanism&…

pr视频剪辑、福昕剪辑……四款剪辑视频大比拼

最近入了视频剪辑的坑&#xff0c;我最近在尝试不同的视频剪辑软件&#xff0c;想找到最适合我的那一款。今天&#xff0c;我就来跟大家分享一下我使用福昕视频剪辑、爱拍视频剪辑、Adobe Premiere&#xff08;简称PR&#xff09;和Shotcut这四款软件时的一些体验和感受。希望我…

FPGA_传递参数的方式

FPGA Verilog 调用模块后带有 “ #()” 的含义 最后4个LED闪烁控制模块的例化,它们的源码都是 led_controller.v 模块&#xff0c;但它们的名称不一样,分别为“uut_led_controller_clk12m5 ”&#xff0c;“uut_led_controller_clk25m”&#xff0c;“uut_ledcontroller clk50…

Pandas -----------------------基础知识(二)

dataframe读写数据操作 import pandas as pd# 准备数据(字典) data [[1, 张三, 1999-3-10, 18],[2, 李四, 2002-3-10, 15],[3, 王五, 1990-3-10, 33],[4, 隔壁老王, 1983-3-10, 40] ]df pd.DataFrame(data, columns[id, name, birthday, age]) df写到csv文件中 &#xff0c;…

Azure Pipeline 常用任务记录

各种任务的查询&#xff1a; 任务查询 下载类 1 DownloadPackage1 从 Azure Artifacts 中的包管理源下载包 2 DownloadSecureFile1 下载安全文件&#xff0c;这里的安全文件在Library中上传&#xff0c;默认的位置会传到$(Agent.TempDirectory) 3 DownloadBuildArtifacts1…

shopify主题开发中给产品页设置多个模板

在shopify开发中&#xff0c;有时候商家可能需要为不同的产品去设置自己想要的产品页模板。下面主要教大家如何为产品类型页面设置多个模板&#xff0c;大家只要按照下面几个步骤就可以轻松实现产品的定制化页面&#xff1a; 1、首先在定制器创建产品模板 进入商品自定义页面…

【LangChain系列】实战案例5:用LangChain实现灵活的Agents+RAG,该查时查,不该查时就别查

目前为止&#xff0c;我们实现的RAG练习中&#xff0c;答案都是全部来源于检索到的文本内容。而检索过程可能在某些情况下是不需要的。 如何优化这个过程&#xff0c;让我们的RAG程序在必要时才去检索&#xff0c;不必要时&#xff0c;直接使用大模型原有数据来回答呢&#xf…

M2型TAM靶向肽CRV; Ahx-CRVLRSGSC ;

【M2型TAM靶向肽CRV 简介】 M2型TAM靶向肽CRV是一种用于靶向肿瘤相关巨噬细胞&#xff08;TAMs&#xff09;中M2型亚群的多肽。这种多肽序列为CRVLRSGSC&#xff0c;包含一对二硫键&#xff0c;其三字母代码为Cys-Arg-Val-Leu-Arg-Ser-Gly-Ser-Cys&#xff08;Cys-Cys&#xff…

什么是json?

JSON简介:JSON的全称为JavaScript Object Nation(JavaScript 对象表示语法)&#xff0c;基于 ECMAScript&#xff0c;存放的是的类似于键值对&#xff0c;本质上来说是javascript的数据类型&#xff0c;是一种轻量级的数据交互格式&#xff0c;简单来说呢&#xff0c;json就是一…

万博智云CEO王嘉在华为全联接大会:以创新云应用场景,把握增长机遇

一、大会背景 2024年9月19-21日&#xff0c;第九届华为全联接大会将在上海世博展览馆和上海世博中心举办。作为华为的旗舰盛会&#xff0c;本次大会以“共赢行业智能化”为主题邀请了众多思想领袖、商业精英、技术专家、合作伙伴、开发者等业界同仁&#xff0c;从战略、产业、…

NS2159 1A 线性锂离子电池充电管理IC

1 特性 ● 输入电压范围 4.5V-26V ● 输入过压保护电压 6.0V ● 用于单节锂离子电池线性工作模式充电 ● 支持 0V 电池电压充电 ● 涓流/恒流/恒压三段式充电 ● 内部预设 4.2V 充电浮充电压 ● 1A 可编程充电电流 ● C/10 充电终止功能 ● 内置自动复充功能 ● 内置过温保护功…

51单片机-DA(数字转模拟)(呼吸灯)

作者&#xff1a;Whappy 个人理解&#xff1a;将电压或电流信号进行等分或不等分&#xff08;高电平的电压范围和低电平的范围&#xff0c;如0-5v&#xff0c;0-1.8位低电平&#xff0c;3.8-5v为高电平&#xff09;&#xff0c;同样也是通过采样&#xff0c;量化等操作将不连续…

智能创造的幕后推手:AIGC浪潮下看AI训练师如何塑造智能未来

文章目录 一、AIGC时代的算法与模型训练概览二、算法与模型训练的关键环节三、AI训练师的角色与职责四、AI训练师的专业技能与素养五、AIGC算法与模型训练的未来展望《AI训练师手册&#xff1a;算法与模型训练从入门到精通》亮点内容简介作者简介谷建阳 目录 《医学统计学从入门…

2024图纸加密软件最佳选择!10款超好用的图纸加密软件推荐!

随着企业对数据安全的重视不断提升&#xff0c;尤其是在涉及重要设计图纸等机密文件的管理上&#xff0c;选择一款高效且安全的图纸加密软件显得尤为重要。2024年&#xff0c;我们精选了10款超好用的图纸加密软件&#xff0c;帮助企业保护知识产权与机密数据的安全。 1.安秉图纸…

多语言文本 AI 情感分析 API 数据接口

多语言文本 AI 情感分析 API 数据接口 AI / 文本处理 AI 模型快速分析文本情感倾向 多语言文本 / 情感分析。 1. 产品功能 支持多语言文本情感分析&#xff1b;基于特定 AI 模型&#xff0c;快速识别文本情感倾向&#xff1b;适用于评论分析、舆情监控等场景&#xff1b;全接…

2024/9/23 leetcode 148题 排序链表

目录 148.排序链表 题目描述 题目链接 解题思路与代码 148.排序链表 题目描述 给你链表的头结点 head &#xff0c;请将其按 升序 排列并返回 排序后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [4,2,1,3] 输出&#xff1a;[1,2,3,4]示例 2&#xff1a; 输入&am…

【Python】入门学习1:开发前的准备

准备工作&#xff1a; 1、电脑系统&#xff1a;windows 64位&#xff1b; 2、python学习所需工具&#xff1a;“解释器、编译器”&#xff1b; &#xff08;1&#xff09;python 解释器&#xff1a;解释代码的&#xff0c;把 python 计算机语言翻译给计算机认识&#xff1b;…

双通道隔离驱动之选,SLMi823x系列SLMi8235BDCG-DG可编程死区满足您需求

SLMi823x系列SLMi8235BDCG-DG双通道死区可编程的隔离驱动器。SLMi823x系列SLMi8235BDCG-DG配置为双输入&#xff0c;双输出驱动器。另外&#xff0c;SLMi823x系列SLMi8235BDCG-DG峰值输出电流为 4.0A。 所有输出驱动器的 VDDA/B 电源电压最高到40V。3V 至 18V 的 VDDI 宽范围输…