【Pytorch】大语言模型中的CrossEntropyLoss

文章目录

  • 前言
  • 什么是CrossEntropyLoss
  • 语言模型中的CrossEntropyLoss
    • 计算loss的前期准备
    • CrossEntropyLoss的输入
    • CrossEntropyLoss的输出
  • 额外说明

前言

在大语言模型时代,我们常常使用交叉熵损失函数来计算loss,因此,理解该loss的计算流程有助于帮助我们对训练过程有更清晰的认知。本文从以下几个角度介绍nn.CrossEntropyLoss()

  • 使用该函数的前期准备:如何组织函数的输入(logits & labels)
  • 该函数流程
  • 常用参数
  • 该文章内容仅为个人理解,如有误解,欢迎讨论

什么是CrossEntropyLoss

这部分并不是本文的重点,我们仅介绍在语言模型的训练过程中,如何利用该loss

  • 相关信息可见:本人博客
  • 以及官网:CrossEntropyLoss官网

语言模型中的CrossEntropyLoss

计算loss的前期准备

huggingface-transformers源码中,我们在语言模型的forward中总是能看到这样一段函数。我们以LlamaForCausalLM为例:Llama源码

if labels is not None:# Shift so that tokens < n predict nshift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss_fct = CrossEntropyLoss()shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)# Enable model parallelismshift_labels = shift_labels.to(shift_logits.device)loss = loss_fct(shift_logits, shift_labels)if not return_dict:output = (logits,) + outputs[1:]return (loss,) + output if loss is not None else output

对于Decoder-only模型,在训练时,我们的目标是next token prediction,任务流程如下

  • 假定我们是常规的问答任务,问题是“where is the capital of China“,label为“The capital is Beijing”。该任务的目标为,当输入为“where is the capital of China“时,

  • 我们对question和label进行拼接和tokenize化,一般转化结果 (tokenize忽略) 为:< bos > where is the capital of China < sep > The capital is Beijing < eos >

    • < bos>为句子开头的标志
    • < sep>用于分隔question和label,本质作用是,当模型看到时就知道:问题结束了,下一个token要输出答案了
    • < eos>为生成结束的标志
    • 假定每个词算一个token (忽略空格),那么输入一共有13个token
  • 这时我们将整个序列输入到模型中,模型在每个token的位置都生成一个向量,我们利用lm_head将最后一层的hidden state转化成词表大小的向量logits,用于后续利用Softmax确定每个token的概率

  • 现在模型有了输出logits,怎么计算loss?

    • 对比labels和logits之间的差异来计算loss

    • 现在一共有13个token,生成了13个logits,每个logits都是用于生成next token的。那么很直接的,我们来对比该logits生成的next token准不准就好了

      • 输入:< bos> where is the capital of China < sep> The capital is Beijing < eos>

      • 对比情况为:< sep>->The, The->capital, …, is->Beijing, Beijing->< eos>

        • < sep>对应位置要生成The,…, Beijing对应位置要输出< eos>
      • 我们可以将输入右移一位作为labels: where is the capital of China < sep> The capital is Beijing

        • 可以看到,对于输入来说, < eos>位置没有对应的需要生成的token,因此我们去掉该token
        • 对于labels,< bos>不需要生成,因此我们去掉该token
      • 因此,我们在计算loss时,对logits去尾,labels是输入掐头且右移一位

      • 在代码中对应

          shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()
        

CrossEntropyLoss的输入

此时还不能直接将shift_logitsshift_labels进行对比,来计算loss。因为我们上面的操作只是为了<sep> The capital is BeijingThe capital is Beijing <eos>中的token能一一对应起来,对于其他部分生成的token,我们并没有要求(因为不是answer,不需要生成)

  • CrossEntropyLoss函数中有一个参数为ignore_idx默认值为-100。labels值设置为-100的位置不会计算loss
  • 因此我们将除了需要计算loss的位置 (最后5个位置)的labels都设置为-100
  • 最终,需要输入到CrossEntropyLoss中的inputs和labels为
    • inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing]对应的logits
      • 注意:不需要进行Softmax,直接传logits即可,函数内部有更稳定的Softmax计算方式
    • labels为: [-100, -100, -100, -100, -100, -100, -100, The, capital, is, Beijing, < eos>]
    • 我们在训练时,构造输入和labels要注意构造为这种形式

CrossEntropyLoss的输出

默认情况下,输出为mean,即各个token计算得到loss的平均值(在token-level上平均,分母是token的个数)

import torch
import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(logits, labels)print(f"Loss: {loss}")
>>> Loss: 0.51058030128479
  • 常用参数:

    • reduction:控制loss的输出形式,共三种'none', 'mean', 'sum',默认为'mean'

      • mean: 每个token计算得到的loss的平均值

      • none: 直接返回每个token计算得到的loss

        • 例子:

          import torch
          import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
          logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
          labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
          criterion = nn.CrossEntropyLoss(reduction='none')# 计算损失
          loss = criterion(logits, labels)print(f"Loss: {loss}")
          >>> Loss: tensor([0.4170, 0.0000, 0.6041])
          
      • sum: 所有token对应loss求和

额外说明

对最上面的代码补充说明

  shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)
  • 训练数据往往是按batch组织的,shape为(batch_size, seq_len, vocab_size)
  • 我们将所有batch的token压缩为一个序列,计算整个序列的loss,这样比较方便

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/145056.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

欧美海外仓系统有哪些服务商选择?

在跨境电商的全球化浪潮中&#xff0c;欧美市场以其成熟的电商生态和庞大的消费群体&#xff0c;成为了众多跨境卖家竞相争夺的高地。为了提升物流效率、降低成本并增强客户体验&#xff0c;海外仓成为了不可或缺的一环。而海外仓系统的选择&#xff0c;则直接关系到仓库的运营…

排序-----选择排序

首先介绍几种排序的分类&#xff1a; 选择排序是每次都遍历&#xff0c;标记出最小的元素&#xff0c;然后把它放在前面。 本文介绍优化后的版本&#xff1a;每次遍历标记出最小的和最大的元素&#xff0c;分别放到前面和后面。&#xff08;注意这里是找到对应的下标&#xff0…

数据结构与算法——Java实现 6.递归

要学会试着安静下来 —— 24.9.17 一、递归的定义 计算机科学中&#xff0c;递归是一种解决计算问题的方法&#xff0c;其中解决方案取决于同一类问题的更小子集 说明: ① 自己调用自己&#xff0c;如果说每个函数对应着一种解决方案&#xff0c;自己调用自己意味着解决方案是…

2024/9/20 使用QT实现扫雷游戏

有三种难度初级6x6 中级10x10 高级16x16 完成游戏 游戏失败后&#xff0c;无法再次完成游戏&#xff0c;只能重新开始一局 对Qpushbutton进行重写 mybutton.h #ifndef MYBUTTON_H #define MYBUTTON_H #include <QObject> #include <QWidget> #include <QPus…

jdk版本更换以及遇到的问题略谈(以jdk1.8和jdk11为例)

目录 在我看来 遇到的问题 原因以及解决方法 方法一&#xff1a;禁止误改误删 方法二&#xff1a;bat文件驱动运行 方法三&#xff1a;cmd命令 方法四&#xff1a;修改注册表&#xff08;不推荐&#xff09; 最近在进行漏洞复现&#xff08;shiro550&#xff09;的时候&…

DAY20信息打点-红蓝队自动化项目资产侦察武器库部署企查产权网络空间

2.自动化-网络空间-AsamF 1.去GitHub上下载项目之后使用CMD打开 2.输入命令AsamF_windows_amd64.exe -v生成配置文件 3.AsamF会在~/.config/asamf/目录下生成config.json文件 C:\Users\Acer\.config\asamf 5.根据文档输入命令去查询所需信息&#xff08;已经没有用了&#x…

jeecg自定义sql查询,并使用高级查询的参数

jeecg框架的高级查询界面如下 后台代码&#xff1a; controller层&#xff1a; GetMapping(value "/list")public Result<IPage<Receiver>> queryPageList(Receiver receiver,RequestParam(name"pageNo", defaultValue"1") Integ…

进程和线程问题解答

线程和进程的概念、区别 进程是操作系统进行资源分配的基本单位&#xff0c;拥有独立的地址空间&#xff0c;包括代码、数据、堆、栈等。进程间的切换开销较大。 线程是进程中的一个执行单元&#xff0c;是系统中最小的执行单位&#xff0c;共享进程的资源&#xff0c;如代码…

Python类及元类的创建流程

Python类及元类的创建流程 代码运行结果再看type和object的关系和约定type和object具有的方法不一样看代码和运行结果&#xff0c;可以完全理解python的执行过程。再补充几点&#xff0c; 代码 class MetaCls(type):print(0>>>, MetaCls, 0)def __init__(self, name,…

基于微型5G网关的酒店服务机器人应用

智能机器人在酒店中已经越来越常见&#xff0c;并且也是提升客户体验、提高服务效率的重要工具。然而&#xff0c;尽管这些机器人在自动化服务方面可以发挥着重要作用&#xff0c;但它们仍然面临着一些通信、组网和在线管理方面的痛点。 针对这些难题&#xff0c;可以通过部署微…

Python边界值测试工具:生成指定大小文件

在进行软件测试的过程中&#xff0c;经常会有文件上传功能的需求&#xff08;例如头像上传、商品图上传等&#xff09;&#xff0c;这时候就需要考虑文件大小的边界值&#xff0c;例如只可上传1-2MB的图片&#xff0c;5-10MB的文件&#xff0c;想要验证需求的话&#xff0c;就需…

CVPT: Cross-Attention help Visual Prompt Tuning adapt visual task

论文汇总 当前的问题 图1:在VTAB-1k基准测试上&#xff0c;使用预训练的ViT-B/16模型&#xff0c;VPT和我们的CVPT之间的性能和Flops比较。我们将提示的数量分别设置为1、10、20、50,100,150,200。 如图1所示&#xff0c;当给出大量提示时&#xff0c;VPT显示了性能的显著下降…

电脑ip会因为换了网络改变吗

在当今数字化时代&#xff0c;IP地址作为网络世界中的“门牌号”&#xff0c;扮演着至关重要的角色。它不仅是设备在网络中的唯一标识&#xff0c;也是数据交换和信息传递的基础。然而&#xff0c;对于普通用户而言&#xff0c;一个常见的问题便是&#xff1a;当电脑连接到不同…

三菱变频器Modbus-RTU 通讯规格

能够从变频器的 RS-485 端子使用 Modbus-RTU 通讯协议&#xff0c;进行通讯运行和参数设定。 NOTE: 1、使用 Modbus-RTU 通讯协议时&#xff0c;请设定Pr.549 协议选择 “1” 2、从主机按地址0(站号0)进行hodbus-RTU通讯时&#xff0c;为广播通讯&#xff0c;变频器不向主机发…

QT编译后,如何手动运行,或复制到其他机器运行

编译后&#xff08;文件名叫Work.exe&#xff09;&#xff0c;通过QT功能&#xff0c;是可以成功运行的。如果在目录中双击&#xff0c;或复制到其他机器上运行&#xff0c;就会失败。怎么办&#xff1f; 打开命令窗口 运行命令 D:\Work\build\release>windeployqt Work.e…

写论文去哪个网站?2024最佳五款AI毕业论文学术网站

在2024年&#xff0c;AI技术在学术写作领域的应用已经变得越来越普遍。为了帮助学生和研究人员更高效地完成毕业论文的撰写任务&#xff0c;市场上涌现了许多优秀的AI论文写作工具。本文将重点推荐五款最佳的AI毕业论文学术网站&#xff0c;并特别强调千笔-AIPassPaper的优势。…

【应用开发三】 input子系统介绍

文章目录 1 名词解释2 输入设备编程框架2.1 input子系统2.2 读取数据流程2.3 input_event结构体2.3.1 type&#xff08;哪类事件&#xff09;2.2 code&#xff08;具体事件&#xff09;2.3 value&#xff08;数值&#xff09; 2.4 数据同步2.5 读取start input_event数据 1 名词…

【iOS】——YYModel源码总结

性能优化及优点 YYModel主要用于将JSON数据转换为模型对象&#xff0c;以及将模型对象转换为字典的库。相比于其他的数据转换库例如JSONModel&#xff0c;它更加的轻量级并且性能更高&#xff0c;因为它在很多地方做了优化&#xff1a; 通过CFDictionaryCreateMutable方法将数…

【数据结构与算法 | 灵神题单 | 二叉搜索树篇】力扣99, 1305, 230, 897

1. 力扣99&#xff1a;恢复二叉搜索树 1.1 题目&#xff1a; 给你二叉搜索树的根节点 root &#xff0c;该树中的 恰好 两个节点的值被错误地交换。请在不改变其结构的情况下&#xff0c;恢复这棵树 。 示例 1&#xff1a; 输入&#xff1a;root [1,3,null,null,2] 输出&…

【UWB无线载波通信技术】史上最详细解说!!

UWB定位技术的人员定位系统源码&#xff0c;高精度人员定位系统&#xff0c;自主研发&#xff0c;最高定位精度可达10cm&#xff0c;全套源码合作&#xff01; 目录 简介 基本原理 技术指标 应用 uwb定位系统应用场景 一、‌室内定位与导航‌&#xff1a; 二、‌特定行业…