从FasterTransformer源码解读开始了解大模型(2.1)代码通读03

从FasterTransformer源码解读开始了解大模型(2.2)代码解读03-forward函数

写在前面的话

本篇的内容继续解读forward函数,从650行开始进行解读

零、输出Context_embeddings和context_cum_log_probs的参数和逻辑

从653行开始,会从输入的请求tensors中读取一个配置,如果请求中配置了is_return_context_embeddings参数并设置为true时,则会在返回参数中增加一个context_embeddings的tensor,这个tensor中包含的数据是输入经过了ContextDecoder过程的所有的层之后的logits,并对其进行求和。可能有些类似于强化学习(RLHF)之类的场景会用到这里的输出,所以在这里做了一层准备。

跳转到1093行,可以看见,这里调用了invokeSumLengthDimension这个kennels,来将context_decoder_output_buf这块buf中的显存数据拷贝到输出tensor的context_embeddings中,可以简单看一下这个kernel的实现,在src/fastertransformer/kernels/gpt_kernels.cu中

template<typename T>
void invokeSumLengthDimension(float*       out_buf,const T*     in_buf,const size_t batch_size,const size_t input_length,const size_t hidden_dim,cudaStream_t stream)
{dim3 gridSize(batch_size);dim3 blockSize(256);sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}template<typename T>
__global__ void sum_length_dimension(float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{const int bidx = blockIdx.x;for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {float accum = 0.0f;for (int step = 0; step < input_length; step++) {accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);}out_buf[bidx * hidden_dim + hidx] = accum;}
}

从kernel中可以看出,这个kernel任务划分为按照batch_size维度进行了grid task分配,并为每个任务划分了256个线程。在kernel内部则是将每个batch内所有输入的logits按照输入length维度进行了累加,并拷贝到输出buf中。

类似的一个设置和处理环节,在783行还处理了is_return_context_cum_log_probs,这里会将contextDecoder完成之后的输出进行cum log计算。其中主要的处理逻辑函数是403行的computeContextCumLogProbs函数,进入这个函数后,在425到502行的处理逻辑是,首先将context_decoder的输出进行layernorm,然后使用cublas矩阵乘,将词表维度的logits计算出来。在此之后,在504行,则会使用invokeLogProbFromLogits这个kernel.

template<typename T>
void invokeLogProbFromLogits(float*       cum_log_probs,const T*     logits,const int*   input_ids,const int*   input_lengths,const size_t max_input_length,const size_t batch_size,const size_t vocab_size,const size_t vocab_size_padded,void*        workspace,const size_t workspace_size,cudaStream_t stream,const bool   batch_first)
{// A batched version of log prob computation.//// cum_log_probs: [batch_size]// logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]// input_ids: [max_input_length, batch_size] or [max_input_length, batch_size]// input_lengths: [batch_size]// workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.FT_LOG_DEBUG(__PRETTY_FUNCTION__);// block_size should be multiple of 32 to use warpReduceMax.const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;assert(block_size % 32 == 0);assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);assert(vocab_size <= vocab_size_padded);float* log_probs = reinterpret_cast<float*>(workspace);int    gx        = batch_first ? batch_size : max_input_length - 1;int    gy        = batch_first ? max_input_length - 1 : batch_size;dim3   grid(gx, gy);log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,logits,input_ids,input_lengths,max_input_length,batch_size,vocab_size,vocab_size_padded,batch_first);accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}__global__ void accumulate_log_probs(float*       cum_log_probs,const float* log_probs,const int*   lengths,const size_t max_input_length,const size_t batch_size,const bool   batch_first)
{// Accumulate the log probability along with the sequence dimension.//   cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]//// cum_log_probs: [batch_size], cumulative log probability// log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],//   log probability of each token// lengths: [batch_size], sequence lengths// batch_size: [1], batch_size. in case of beam > 1, batch x beam.int bidx = blockIdx.x;   // batch dimint tidx = threadIdx.x;  // step dimif (bidx < batch_size) {int length = lengths[bidx];// reposition logits to data for the current batch.log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;int   stride      = batch_first ? 1 : batch_size;  // stride along with seq dim.float local_accum = 0.0f;for (int step = tidx; step < length - 1; step += blockDim.x) {local_accum += static_cast<float>(log_probs[step * stride]);}float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);if (tidx == 0) {cum_log_probs[bidx] = accum;}}
}

在任务划分阶段,grid的x维度是batch维度,而y维度则是输入长度,进入kernel后,则是按照batch维度对任务的log_probs进行了ReduceSum,并写入到cum_log_probs中。

一、跳过prompt和prefix阶段

回到653行,接下来是ft中处理其自身特殊的prompt和prefix的逻辑,但我们的源码解读可以先跳过这一段。之所以要跳过这一段是在当前主要的大模型处理逻辑中,并不会需要在推理引擎这一层过多地关注prompt和prefix的逻辑。在对话的大模型agent中,对于第n次对话的输入/输出,其真正要进行forward的输入,往往是由前n-1轮模型的输出和用户的输入共同组成的。从推理引擎的角度来看,prefill阶段在真正端到端时延的占比并不算高,所以专门设计一个处理prompt的逻辑(还需要常驻的显存空间)多少显得有些得不偿失了

二、正常的逻辑起始和输入展开

让我们直接来到865行,考虑以下一种场景来简化我们的源码解读的结构:beam_width=1,tp=1,pp=1,use_shared_contexts不启用,也就是说,没有beam search,在单机单卡上进行一次简单的,不共享contexts的推理服务的进行。

从866行开始,由于是计算起始,所以先对一些计算中的buff进行清空。在870行,由于不进行beam search,所以也不需要对beam search用到的buf进行清空。

在877到887行,如果词表大小不对齐的话,需要将词表计算的权重进行对齐拷贝。

在889到905行,不使用shared context的话,这一段的逻辑也不需要进行处理。在914行处理prompt的逻辑也可以跳过,直接进入955行的处理逻辑。在955行,由于我们的输入是带batch的,所以需要将batch进行展开(tiled),这里使用的是kernel invokeTileGptPromptInputs

void invokeTileGptPromptInputs(int*         tiled_input_ids,int*         tiled_input_lengths,int*         tiled_prompt_lengths,const int*   input_ids,const int*   input_lengths,const int*   prefix_prompt_lengths,const int    batch_size,const int    beam_width,const int    max_input_length,cudaStream_t stream)
{dim3 grid(batch_size, beam_width);dim3 block(min(1024, max_input_length));if (prefix_prompt_lengths != nullptr) {tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,tiled_input_lengths,tiled_prompt_lengths,input_ids,input_lengths,prefix_prompt_lengths,max_input_length);}else {tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,tiled_input_lengths,tiled_prompt_lengths,input_ids,input_lengths,prefix_prompt_lengths,max_input_length);}
}template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int*       tiled_input_ids,int*       tiled_input_lengths,int*       tiled_prompt_lengths,const int* input_ids,const int* input_lengths,const int* prefix_prompt_lengths,const int  max_input_length)
{if (threadIdx.x == 0) {tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];if (PREFIX_PROMPT) {tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];}}for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =input_ids[blockIdx.x * max_input_length + index];}
}

在任务划分阶段,按照batch维度进行划分,每个任务起了至少1024个线程来进行拷贝。由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再处理

在kernel中,由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再按照batch维度进行处理,将一个batch*max_input_length的数据进行展开,并拷贝到对应的buf,这有利于我们进行接下来的后续embedding和MHA计算
在这里插入图片描述

下一回预告

下一回继续讲解forward函数中的处理逻辑,会简单讲解embedding, pre_layernorm,以及进入attention_layer之后的初步讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1474821.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Edge浏览器油猴插件的安装与使用

油猴 (又称篡改猴或Tampermonkey) 是最流行的浏览器扩展之一。它允许用户自定义并增强网页的功能。用户脚本是小型 JavaScript 程序&#xff0c;可用于向网页添加新功能或修改现有功能。使用油猴&#xff0c;您可以轻松在任何网站上创建、管理和运行这些用户脚本。 1.插件的安…

spRAG框架学习小结

spRAG是什么 spRAG是一个针对非结构化数据的检索引擎。它特别擅长处理对密集文本的复杂查询&#xff0c;比如财务报告、法律文件和学术论文。有两种关键方法用于提高性能&#xff0c;超越了普通的RAG系统&#xff1a; 自动上下文&#xff08;AutoContext&#xff09;&#xff…

SQLServer Manager Studio扩展开发从入门到弃坑

Visualstudio的已经开发好了&#xff0c;可这个就是不行&#xff0c;直接运行点这些按钮加载失败&#xff0c;而我直接不调试模式&#xff0c;则直接什么都没有&#xff0c;调试 发现是根本没触发逻辑的。 文档资料太少&#xff0c; 我换了几个ssms.exe都不行&#xff0c;18-20…

【电路笔记】-AB类放大器

AB类放大器 文章目录 AB类放大器1、概述2、AB类放大器介绍3、AB类放大器效率4、偏置方法4.1 电压偏置4.2 分压网络4.3 电位器偏置4.4 二极管偏置5、二极管网络和电流源6、AB类放大器的电源分配7、总结1、概述 A类放大器提供非常好的输出线性度,这意味着可以忠实地再现信号,但…

Nacos架构设计

Nacos1.X架构设计 Nacos2.X架构修改

[240707] X-CMD v0.3.14: cb gh fjo zig 模块增强;新增 lsio 和 pixi 模块

目录 X-CMD 发布 v0.3.14✨ advise&#xff1a;Bash 环境下自动补全时&#xff0c;提供命令的描述信息✨ cb:支持下载指定版本的附件资源✨ gh:支持下载指定版本的附件资源✨ fjo:支持下载指定版本的附件资源✨ zig&#xff1a;新增 pm 和 zon 子命令✨ lsio&#xff1a;用于查…

Raylib 坐标系

draftx 符号调整为正数 发现采样坐标系原点0&#xff0c;0 在左上角&#xff0c;正方向 右&#xff0c;下 绘制坐标系 原点0&#xff0c;0 在左下角&#xff0c;正方向 右&#xff0c;上 拖拽可得 #include <raylib.h> // 重整原因&#xff1a;解决新函数放大缩小之下…

华媒舍:8种网站构建推广方法全揭密!

网站构建成为了推广宣传和宣传品牌的关键一环。对于新手&#xff0c;搭建和营销推广网站有可能是一项全新的挑战。下面我们就为大家介绍8种网站搭建和营销推广技巧&#xff0c;帮助你在这些方面取得成功。 1.选择适合自己的网站构建平台选择合适的网站构建平台针对构建一个成功…

阿里云ecs服务器,nginx多域名多项目部署教程,含本地部署教程

nginx多域名部署项目 本地部署线上部署 一、本地部署 第一步&#xff1a; winr 输入drivers 打开hosts文件&#xff0c;编辑 加行 127.0.0.1 自定义域名 … 第二步&#xff1a; 下载 nginx 安装好以后 打开ngin安装目录&#xff0c;选择nginx.conf 打开 #user Administ…

【Transformer】transformer模型结构学习笔记

文章目录 1. transformer架构2. transformer子层解析3. transformer注意力机制4. transformer部分释疑 图1 transformer模型架构 图2 transformer主要模块简介 图3 encoder-decoder示意图N6 图4 encoder-decoder子层示意图 1. transformer架构 encoder-decoder框架是一种处理NL…

ZYNQ-LINUX环境C语言利用Curl库实现HTTP通讯

前言 在Zynq-Linux环境中&#xff0c;需要使用C语言来编写APP时&#xff0c;访问HTTP一般可以使用Curl库来实现&#xff0c;但是在Zynq的SDK中&#xff0c;并没有集成该库&#xff0c;在寻找了很多资料后找到了一种使用很方便的额办法。这篇文章主要记录一下移植Curl的过程。 …

一招解决找不到d3dcompiler43.dll,无法继续执行代码问题

当您的电脑遇到d3dcompiler43.dll缺失问题时&#xff0c;首先需要了解d3dcompiler43.dll文件及其可能导致问题的原因&#xff0c;之后便可以选择合适的解决方案。在此&#xff0c;我们将会为您提供寻找d3dcompiler43.dll文件的多种处理方法。 一、d3dcompiler43.dll文件分析 d…

C++入门7——string类详解

目录 1.什么是string类&#xff1f; 2.string类对象的常见构造 2.1 string(); 2.2 string (const char* s); 2.3 string (const string& str); 2.4 string (const string& str, size_t pos, size_t len npos); 2.5 string (const char* s, size_t n); 2.7 验证…

绝区肆--2024 年AI安全状况

前言 随着人工智能系统变得越来越强大和普及&#xff0c;与之相关的安全问题也越来越多。让我们来看看 2024 年人工智能安全的现状——评估威胁、分析漏洞、审查有前景的防御策略&#xff0c;并推测这一关键领域的未来可能如何。 主要的人工智能安全威胁 人工智能系统和应用程…

Java里的Arrary详解

DK 中提供了一个专门用于操作数组的工具类&#xff0c;即Arrays 类&#xff0c;位于java.util 包中。该类提供了一些列方法来操作数组&#xff0c;如排序、复制、比较、填充等&#xff0c;用户直接调用这些方法即可不需要自己编码实现&#xff0c;降低了开发难度。 java.util.…

Python爬虫系列-让爬虫自己写爬虫(半自动化,代替人工写爬虫)

现在的PC、手机客户端等终端设备大量使用了网页前后端技术&#xff0c;另外主流的网站也会经常会更新&#xff0c;导致以前一个月更新一次爬虫代码&#xff0c;变成了天天需要更新代码&#xff0c;所以自动化爬虫技术在当前就显得特别重要&#xff0c;最近我也是在多次更新某个…

赋值运算符重载和const成员函数和 const函数

文章目录 1.运算符重载(1)(2)运算符重载的语法&#xff1a;(3)运算符重载的注意事项&#xff1a;(4)前置和后置重载区别 2.const成员函数3.取地址及const取地址操作符重载4.总结 1.运算符重载 (1) 我们知道内置类型(整形&#xff0c;字符型&#xff0c;浮点型…)可以进行一系…

TB作品】51单片机 Proteus仿真 51单片机SPI显示OLED字符驱动

// GND 电源地 // VCC 接5V或3.3v电源 // D0 P1^4&#xff08;SCL&#xff09; // D1 P1^3&#xff08;SDA&#xff09; // RES 接P12 // DC 接P11 // CS 接P10 OLED显示接口与控制实验报告 背景 OLED&#xff08;有机发光二极管&#xff09;显示器由于其高对比度、低功耗和…

最新版Python安装教程

一、安装Python 1.下载Python 访问Python官网&#xff1a; https:/www.oython.orgl 点击downloads按钮&#xff0c;在下拉框中选择系统类型(windows/Mac OS./Linux等) 选择下载最新稳定版本的Python 以下内容以演示安装Windows操作系统64位的python 左边是稳定发布版本Stabl…

Linux权限概述

一、权限概述 1.权限的基本概念 2.为什么要设置权限 3.linux用户的身份类别 4.user文件的拥有者 5.group文件所属组内用户 6.other其他用户 7.特殊用户root 二、普通权限管理 1.ls -l查看文件权限 2.文件类型以及权限解析 3.文件或文件夹的权限设置 4.通过数字给文件…