模型Alignment之RLHF与DPO

1. RLHF (Reinforcement Learning from Human Feedback)

RLHF 是一种通过人类反馈来强化学习的训练方法,它能够让语言模型更好地理解和执行人类指令。

RLHF 的三个阶段

RLHF 的训练过程一般分为三个阶段:

  1. 监督微调(Supervised Fine-Tuning, SFT)

    • 目的:让模型初步具备按照人类指令生成文本的能力。
    • 数据:使用大量的人工标注数据,包含输入的 prompt 和对应的期望输出。
    • 训练:将这些数据作为监督学习任务,对预训练的大语言模型进行微调。
    • 结果:得到一个初步具有指令跟随能力的模型。
  2. 奖励模型训练(Reward Model Training)

    • 目的:训练一个模型来评估模型生成的文本质量。
    • 数据:收集模型在 SFT 阶段生成的多个不同回复,由人类标注人员对这些回复进行排序,以表示它们相对于给定 prompt 的优劣。
    • 训练:将这些排序数据作为训练数据,训练一个奖励模型。奖励模型的输出是一个标量值,表示生成的文本的质量。
    • 结果:得到一个能够对文本质量进行打分的奖励模型。
  3. 强化学习微调(Reinforcement Learning Fine-Tuning)

    • 目的:使用奖励模型的反馈来进一步优化模型的生成能力。
    • 方法:采用强化学习算法(如 PPO),将语言模型作为策略,奖励模型作为价值函数。
    • 过程
      • 模型生成文本。
      • 奖励模型对生成的文本打分。
      • 根据奖励信号,更新模型的参数,使其生成更高质量的文本。
    • 结果:得到一个在人类反馈下表现更优的语言模型。

技术细节

  • 奖励模型:奖励模型通常是一个分类模型,它学习将不同的文本输出映射到一个连续的奖励值。常用的模型架构包括:
    • 基于 Transformer 的模型:与语言模型类似,具有强大的序列处理能力。
    • 对比学习模型:通过比较不同文本输出的相似性来学习奖励函数。
  • 强化学习算法:PPO(Proximal Policy Optimization)是一种常用的强化学习算法,它能够在保证策略稳定性的同时,高效地更新策略。
  • 数据收集:在 RLHF 的过程中,需要不断地收集新的数据来训练奖励模型和更新策略。这些数据可以来自以下几个方面:
    • 人工标注:由人类标注人员对模型生成的文本进行评估。
    • 用户反馈:收集用户在实际使用中的反馈。
    • 模型自生成:模型通过自生成的方式产生大量数据。

2. PPO在RLHF中的应用

PPO算法概述

PPO(Proximal Policy Optimization)是一种常用的强化学习算法,在RLHF中,它被用来优化语言模型,使其生成的文本能最大化人类反馈的奖励。

核心思想:

  • 策略更新: 通过不断调整模型的参数,使得模型生成的文本能获得更高的奖励。
  • 近端策略更新: 为了保证策略的稳定性,PPO限制了新旧策略之间的差异,避免模型发生剧烈变化。

PPO在RLHF中的具体步骤

  1. 采样数据:

    • 使用当前的语言模型生成多个文本样本。
    • 将这些样本输入到奖励模型中,获得对应的奖励分数。
  2. 计算优势函数:

    • 优势函数表示一个动作的好坏程度相对于平均动作的偏离程度。
    • 在RLHF中,优势函数可以表示为:
      • 优势函数 = 奖励 - 基线
    • 基线通常是所有样本奖励的平均值或一个估计值。
  3. 更新策略:

    • 概率比: 计算新旧策略下,生成相同文本的概率比。
    • 裁剪概率比: 为了防止策略更新过大,将概率比裁剪到一个合理范围内。
    • 计算损失函数:
      • 损失函数通常包含两项:
        • 策略损失: 鼓励模型生成高奖励的文本。
        • KL散度: 限制新旧策略之间的差异。
    • 更新模型参数: 使用梯度下降法来最小化损失函数,从而更新模型的参数。

损失函数的具体形式

PPO的损失函数可以写成如下形式:

L(θ) = 𝔼[min(r_t(θ) * A_t, clip(r_t(θ), 1 - ε, 1 + ε) * A_t)] - β * KL[π_θ, π_θ_old]
  • r_t(θ): 概率比,表示新旧策略下生成相同动作的概率比。
  • A_t: 优势函数。
  • clip: 裁剪操作,将概率比裁剪到[1-ε, 1+ε]的范围内。
  • β: KL散度的系数,用于控制新旧策略之间的差异。
  • KL[π_θ, π_θ_old]: 新旧策略之间的KL散度。

  • 第一项: 鼓励模型生成高奖励的文本。当优势函数为正时,希望概率比越大越好;当优势函数为负时,希望概率比越小越好。
  • 第二项: 限制新旧策略之间的差异,保证策略的稳定性。

3. DPO (Direct Preference Optimization)

DPO的工作原理

DPO的核心思想是:通过比较不同文本生成的优劣,直接优化模型参数。具体来说,DPO会收集大量的文本对,其中每一对文本代表着人类对两个文本的偏好。然后,DPO会训练模型,使得模型能够对新的文本对进行排序,并尽可能地与人类的偏好一致。

DPO与RLHF的区别

特点RLHFDPO
奖励模型需要训练奖励模型无需训练奖励模型
优化目标最大化奖励信号直接优化人类偏好
训练过程两阶段训练(预训练+强化学习)单阶段训练
  • 与RLHF相比,DPO旨在简化过程,直接针对用户偏好优化模型,而不需要复杂的奖励建模和策略优化
  • 换句话说,DPO专注于直接优化模型的输出,以符合人类的偏好或特定目标
  • 如下所示是DPO如何工作的概述

DPO没有再去训练一个奖励模型,使用奖励模型更新大模型,而是直接对LLM进行微调。
实现DPO损失的具体公式如下所示:

  • “期望值” E \mathbb{E} E是统计学术语,表示随机变量的平均值或平均值(括号内的表达式);优化 − E -\mathbb{E} E使模型更好地与用户偏好保持一致
  • π θ \pi_{\theta} πθ变量是所谓的策略(从强化学习借用的一个术语),表示我们想要优化的LLM; π r e f \pi_{ref} πref是一个参考LLM,这通常是优化前的原始LLM(在训练开始时, π θ \pi_{\theta} πθ π r e f \pi_{ref} πref通常是相同的)
  • β \beta β是一个超参数,用于控制 π θ \pi_{\theta} πθ和参考模型之间的分歧;增加 β \beta β增加差异的影响
    π θ \pi_{\theta} πθ π r e f \pi_{ref} πref在整体损失函数上的对数概率,从而增加了两个模型之间的分歧
  • logistic sigmoid函数 σ ( ⋅ ) \sigma(\centerdot) σ()将首选和拒绝响应的对数优势比(logistic sigmoid函数中的项)转换为概率分数

DPO需要两个LLMs,一个策略(policy)模型(我们想要优化的模型)还有一个参考(reference)模型(原始的模型,保持不变)。
我们得到两个模型的输出后,对其输出的结果计算softmax并取log,然后通过target取出预测目标对应的数值。(其实就是做了一个交叉熵,和交叉熵的计算过程一模一样)。通过这个过程我们可以得到每个模型在每个回答上的 π \pi π,于是代入公式计算结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1545110.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

echarts 导出pdf空白原因

问题阐述 页面样式: 导出pdf: 导出pdf,统计图部分为空白。 问题原因 由于代码中进行了dom字符串的复制,而echarts用canvas绘制,canvas内部内容不会进行复制,只会复制canvas节点,因此导出pdf空白。 解决…

1. IP地址介绍

IP地址 一、网络概述1、网络类型2、网络组成、传输介质2.1 组成2.2 传输介质 二、IP地址1、IP地址的表示方法2、IP地址的组成3、IP地址的类型3.1 根据IP地址第一个字节大小来分3.1.1 单播地址 Unicast 3.2 根据IP地址的使用 三、子网掩码 netmask1、默认的子网掩码2、判断多个I…

游戏开发2025年最新版——八股文面试题(unity,虚幻,cocos都适用)

1.静态合批与动态合批的原理是什么?有什么限制条件?为什么?对CPU和GPU产生的影响分别是什么? 原理:Unity运行时可以将一些物体进行合并,从而用一个描绘调用来渲染他们,就是一个drawcall批次。 限…

MyBatis—Plus 快速上手【后端 22】

MyBatis-Plus 使用入门指南 前言 在Java的持久层框架中,MyBatis因其灵活性和易用性而广受欢迎。然而,随着项目规模的扩大,MyBatis的一些重复性工作(如CRUD操作)开始显得繁琐。为了解决这一问题,MyBatis-Pl…

图论系列(dfs)9/24

岛屿问题: 二叉树dfs遍历的框架代码: 要有一个终止条件、访问相邻节点; public void dfs(Treenode root){if(rootnull)return;dfs(root.left);dfs(root.right);} 网格dfs遍历的框架代码: public void dfs(int[][] grid,int x,int y){//如果x、y坐标不在网格里面 …

专业学习|随机规划概观(内涵、分类以及例题分析)

一、随机规划概览 (一)随机规划的定义 随机规划是通过考虑随机变量的不确定性来制定优化决策的一种方法。其基本思想是在决策过程中,目标函数和约束条件可以包含随机因素。 (1)重点 随机规划的中心问题是选择参数&am…

学习一下怎么用git

目录 初始化操作 设置名字: 设置邮箱: 查询状态 初始化本地仓库 清空git bush控制台 git的三个区域 文件提交 将会文件提交到暂存区 暂存指定文件 暂存所有改动文件 查看暂存区里面的文件 将文件提交到版本库 git文件状态查看 ​编辑 暂存区的相关指令…

时序预测:LSTM、ARIMA、Holt-Winters、SARIMA模型的分析与比较

引言 近年来,民航旅客周转量一直是衡量国家或地区民航运输总量的重要指标之一。为了揭示民航旅客周转量背后的规律和趋势,本研究旨在综合分析1990年至2023年的相关数据。 通过单位根检验和序列分解,我们确定了民航旅客周转量数据的非平稳性&…

MySQL(面试题 - 同类型归纳面试题)

目录 一、MySQL 数据类型 1. 数据库存储日期格式时,如何考虑时区转换问题? 2. Blob和text有什么区别? 3. mysql里记录货币用什么字段类型比较好? 4. MySQL如何获取当前日期? 5. 你们数据库是否支持emoji表情存储&…

也遇到过 PIL Image “image file is truncated“的问题

背景前言 属于活久见系列,最近工作上遇了该问题! 背景:前端 APP使用 Android CameraX 的接口,拍摄并上传图片,然后 Python后端服务对图片裁剪与压缩处理。后端服务处理图片时有遇到image file is truncated的情况。还…

Leetcode 螺旋矩阵

算法思想: 这个算法的目标是按照顺时针螺旋的顺序从矩阵中取出元素。为了做到这一点,整个思路可以分成几个关键步骤: 定义边界:首先需要定义四个边界变量: left:当前左边界的索引。right:当前右…

R语言机器学习遥感数据处理与模型空间预测技术及实际项目案例分析

随机森林作为一种集成学习方法,在处理复杂数据分析任务中特别是遥感数据分析中表现出色。通过构建大量的决策树并引入随机性,随机森林在降低模型方差和过拟合风险方面具有显著优势。在训练过程中,使用Bootstrap抽样生成不同的训练集&#xff…

夜间车辆 信号灯识别检测数据集 共3500张 YOLO数据集

夜间车辆 信号灯识别检测数据集 共3500张 YOLO数据集 夜间车辆与交通信号识别检测数据集(Nighttime Vehicle & Traffic Signal Recognition Dataset) 数据集概述 这是一个专为夜间环境设计的车辆和交通信号识别检测数据集,共包含3500张…

将python代码文件转成Cython 编译问题集

准备setup.py from distutils.core import setup from Cython.Build import cythonize import glob# 指定目标目录 python setup.py build -c mingw32 target_dir "src"# 使用glob模块匹配目录中的所有.pyx文件 pyx_files glob.glob(target_dir "/**/*.py&q…

基于STM32F103C8T6单片机的农业环境监测系统设计

本设计是基于STM32F103C8T6单片机的农业环境监测系统,能够完成对作物的生长环境进行信息监测和异常报警,并通过手机APP来实现查看信息和设定阈值的功能。为了实现设计的功能,该系统应该有以下模块:包括STM32单片机模块、水环境PH值…

STM32基础学习笔记-ADC面试基础题6

第六章、ADC 常见问题 1、基本概念:什么是ADC ?作用 ?逐次逼近型 2、传感器本质 ?传感器、电压、ADC数值转化 ? 3、ADC的特征 ? 转化时间、分辨率、精度、量化误差 ? 4、ADC框图组成部分 &…

如何安全有效地进行Temu自养号测评,提升账号权重防关联

在当今市场环境中,许多现成的系统或软件包往往缺乏全面的风险控制能力。掌握自养号测评技术,确保在运营过程中减少对外部系统的依赖。以下是搭建安全、高效运营环境的详细指导,特别针对手机端与电脑端环境的设置,以及关键资源的获…

计算机毕业设计Hadoop+Spark知识图谱体育赛事推荐系统 体育赛事热度预测系统 体育赛事数据分析 体育赛事可视化 体育赛事大数据 大数据毕设

《HadoopSpark知识图谱体育赛事推荐系统》开题报告 一、研究背景及意义 随着互联网技术的迅猛发展和大数据时代的到来,体育赛事数据的数量呈爆炸式增长。用户面对海量的体育赛事信息,常常感到信息过载,难以快速找到感兴趣的赛事内容。如何高…

虚拟机屏幕分辨率自适应VMWare窗口大小

文章目录 环境问题解决办法其它虚拟机和主机间复制粘贴 参考 环境 Windows 11 家庭中文版VMWare Workstation 17 ProUbuntu 24.04.1 问题 虚拟机的屏幕大小,是固定的。如下图,设置的分辨率是800*600,效果如下: 可见&#xff0c…

【PyTorch】数据读取和处理

数据读取机制DataLoader与Dataset 数据处理过程 DataLoader torch.utils.data.DataLoader 功能:构建可迭代的数据装载器 dataset:Dataset类,决定数据从哪里读取及如何读取batchsize:批大小num_works:是否多进程读取…