CatBoost中的预测偏移和排序提升

CatBoost 中,预测偏移(Prediction Shift)排序提升(Ordered Boosting) 是其关键概念和创新点。CatBoost 通过引入 排序提升 解决了梯度提升决策树(GBDT)算法中常见的 预测偏移问题,从而提高了模型的稳定性和性能。以下是对这两个概念的详细解释:


1. 预测偏移(Prediction Shift)

概念

预测偏移是指在梯度提升决策树(GBDT)训练过程中,由于在模型训练阶段同时使用特征和目标变量可能会导致未来信息泄漏,从而影响模型性能和稳定性。

  • 原因
    在标准 GBDT 算法中,训练样本的目标变量会被用来更新模型,而同时目标变量也会被用于特征变换(如目标统计编码)。这种特征变换过程中可能会使用目标变量的全局信息,从而导致未来样本的信息被泄漏到当前训练样本中。

  • 结果

    • 训练误差较低,但在测试集上表现较差(过拟合)。
    • 对目标变量统计不准确,尤其是分类特征的目标统计编码可能引入偏差。

示例

假设有一个分类特征 x x x 和目标变量 y y y

样本 i i i分类特征 x i x_i xi目标变量 y i y_i yi
1A1
2A0
3A1

如果在第 3 个样本训练过程中使用整个数据集的目标统计均值(包括第 3 个样本本身的 y 3 = 1 y_3 = 1 y3=1),则会导致信息泄漏。例如:

目标统计编码:
编码值 = 总目标值 总样本数 = 1 + 0 + 1 3 = 0.67 \text{编码值} = \frac{\text{总目标值}}{\text{总样本数}} = \frac{1 + 0 + 1}{3} = 0.67 编码值=总样本数总目标值=31+0+1=0.67

这会将第 3 个样本的目标变量泄漏到其特征变换中。


2. 排序提升(Ordered Boosting)

概念

排序提升是 CatBoost 提出的用于解决 预测偏移问题 的方法。其核心思想是在每一轮训练中,严格按照样本的时间或排列顺序,只使用当前样本之前的数据计算特征变换(如目标统计编码),避免了未来信息泄漏。


Ordered Boosting 的实现原理

CatBoost 的 排序提升 使用了一种特殊的数据划分和特征计算方式:

  1. 样本顺序化:

    • 假设样本被排列为 ( x σ 1 , y σ 1 ) , ( x σ 2 , y σ 2 ) , … , ( x σ n , y σ n ) (x_{\sigma_1}, y_{\sigma_1}), (x_{\sigma_2}, y_{\sigma_2}), \dots, (x_{\sigma_n}, y_{\sigma_n}) (xσ1,yσ1),(xσ2,yσ2),,(xσn,yσn),其中 σ \sigma σ 表示样本的排列顺序。
    • 在训练第 i i i 个样本时,仅使用前 i − 1 i-1 i1 个样本的数据来计算特征值。
  2. 目标统计的顺序计算:

    • 对于分类特征,目标统计值的计算严格遵循样本顺序。例如,计算第 i i i 个样本的目标统计值 T S ( x i ) TS(x_i) TS(xi) 时,仅基于样本 ( x 1 , y 1 ) , ( x 2 , y 2 ) , … , ( x i − 1 , y i − 1 ) (x_1, y_1), (x_2, y_2), \dots, (x_{i-1}, y_{i-1}) (x1,y1),(x2,y2),,(xi1,yi1) 的目标变量 y j y_j yj
    • 避免了将当前样本 y i y_i yi 或未来样本的目标值泄漏到统计值中。
  3. 模型更新的顺序化:

    • CatBoost 使用排序提升算法训练决策树时,每棵树的分裂决策仅基于当前模型状态和之前的数据更新。

排序提升算法的伪代码

如下图 14-2 中描述的伪代码:

  1. 对训练样本 ( x , y ) (x, y) (x,y) 按顺序排列为 ( x σ 1 , y σ 1 ) , ( x σ 2 , y σ 2 ) , … , ( x σ n , y σ n ) (x_{\sigma_1}, y_{\sigma_1}), (x_{\sigma_2}, y_{\sigma_2}), \dots, (x_{\sigma_n}, y_{\sigma_n}) (xσ1,yσ1),(xσ2,yσ2),,(xσn,yσn)
  2. 初始化模型 M 0 M_0 M0
  3. 对于每个样本 i i i
    • 根据模型状态 M i − 1 M_{i-1} Mi1 和前 i − 1 i-1 i1 个样本的目标变量 y σ j y_{\sigma_j} yσj 计算目标统计值。
    • 更新模型状态 M i = M i − 1 + Δ M M_i = M_{i-1} + \Delta M Mi=Mi1+ΔM,其中 Δ M \Delta M ΔM 是模型的增量更新(如一棵树的增量效果)。
  4. 输出最终模型 M n M_n Mn

排序提升的优点
  1. 避免信息泄漏:

    • 通过按顺序计算特征值和模型更新,确保每个样本的特征计算只依赖于之前的样本信息。
    • 解决了传统梯度提升算法中的预测偏移问题。
  2. 提高模型鲁棒性:

    • 排序提升能够更好地适应分类特征中高基数、稀疏类别的情况。
    • 即使样本数量有限,也能生成稳定的特征统计值。
  3. 改进模型性能:

    • 避免了模型过拟合,提升了测试集上的性能。

排序提升的一个例子

假设训练样本如下:

样本 i i i分类特征 x i x_i xi目标变量 y i y_i yi
1A1
2B0
3A1
4B1
目标统计值的计算:

对于分类特征 x x x,计算目标统计值 T S ( x i ) TS(x_i) TS(xi) 时:

  1. 第 1 行:

    • T S ( x 1 ) TS(x_1) TS(x1):没有之前的样本,所以使用全局均值 p p p
  2. 第 2 行:

    • T S ( x 2 ) TS(x_2) TS(x2):类别 B B B 的目标统计值基于之前样本:
      T S ( x 2 ) = p TS(x_2) = p TS(x2)=p
  3. 第 3 行:

    • T S ( x 3 ) TS(x_3) TS(x3):类别 A A A 的目标统计值基于第 1 行:
      T S ( x 3 ) = y 1 1 = 1 TS(x_3) = \frac{y_1}{1} = 1 TS(x3)=1y1=1
  4. 第 4 行:

    • T S ( x 4 ) TS(x_4) TS(x4):类别 B B B 的目标统计值基于第 2 行:
      T S ( x 4 ) = y 2 1 = 0 TS(x_4) = \frac{y_2}{1} = 0 TS(x4)=1y2=0

总结

  1. 预测偏移(Prediction Shift)

    • 是由于目标变量泄漏到特征变换中引起的模型训练问题,导致过拟合和不稳定性。
  2. 排序提升(Ordered Boosting)

    • 是 CatBoost 的核心创新,通过严格按照时间或排列顺序训练模型,避免了预测偏移问题。
    • 在分类特征处理、目标统计值计算和模型更新中都有应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/13395.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【python系列】python内置函数print()和input()

1.前言 正式开始学习python编程基础知识,首先要建立正确的学习姿势,什么姿势呢,当然不是躺着。首先要学会看语法,学习每一个内置函数都要先把语法和语义理解,再结合勤于练习。有些同学可能英语不太好,这里…

并发基础:(淘宝笔试题)三个线程分别打印 A,B,C,要求这三个线程一起运行,打印 n 次,输出形如“ABCABCABC....”的字符串

🚀 博主介绍:大家好,我是无休居士!一枚任职于一线Top3互联网大厂的Java开发工程师! 🚀 🌟 在这里,你将找到通往Java技术大门的钥匙。作为一个爱敲代码技术人,我不仅热衷于探索一些框架源码和算法技巧奥秘,还乐于分享这些宝贵的知识和经验。 💡 无论你是刚刚踏…

字节、快手、Vidu“打野”升级,AI视频小步快跑

文|白 鸽 编|王一粟 继9月份版本更新之后,光锥智能从生数科技联合创始人兼CEO唐家渝朋友圈获悉,Vidu大模型将于本周再次进行版本升级,Vidu-1.5版本即将上线。 此版本更新方向仍是重点延伸大模型的泛化能力和主体…

redis实现消息队列的几种方式

一、了解 众所周知,redis是我们日常开发过程中使用最多的非关系型数据库,也是消息中间件。实际上除了常用的rabbitmq、rocketmq、kafka消息队列(大家自己下去研究吧~模式都是通用的),我们也能使用redis实现消息队列。…

JVM(一、基础知识)

JVM虚拟机的灵魂三问 JVM是什么? 广义上是一种规范,狭义上的是JDK中的JVM虚拟机,虚拟机模拟计算机的组成部分,可以运行我们写的应用程序,是对操作系统的一层抽象,把我们的应用程序和操作系统解耦&#xff0…

问题分析与解决:Android开机卡动画问题分析

1. 问题背景及描述 在一个android设备的开发的项目中遇到了一个比较典型的问题:在主板贴片完成后,首次刷入androdi固件验证时,遇到了按键出发开机后,系统启动到android动画界阶段时一直循环卡在此阶段,无法进入桌面。如下如所示: 此问题在许多android项目的首次点亮阶段均…

视频会议接入GB28181视频指挥调度,语音对讲方案

传统的视频会议指挥调度系统目前主流的互联网会议大部分都是私有协议,功能都很独立。目前主流的视频监控国标都最GB平台,新的需求要求融合平台要接入监控等设备,并能实现观看监控接入会议,实时语音设备指挥现场工作人员办公实施。…

跟着尚硅谷学vue2—进阶版1.0—组件化编程

2. Vue 组件化编程 1. 传统方式和使用组件方式编写的对比 1. 传统方式编写应用 2. 使用组件方式编写应用 2. 模块与组件、模块化与组件化 1. 模块 理解: 向外提供特定功能的 js 程序, 一般就是一个 js 文件为什么: js 文件很多很复杂作用: 复用 js, 简化 js 的编写, 提高 j…

WebRTC视频 01 - 视频采集整体架构

一、前言: 我们从1对1通信说起,假如有一天,你和你情敌使用X信进行1v1通信,想象一下画面是不是一个大画面中有一个小画面?这在布局中就叫做PIP(picture in picture);这个随手一点&am…

【大数据学习 | HBASE高级】rowkey的设计,hbase的预分区和压缩

1. rowkey的设计 ​ RowKey可以是任意字符串,最大长度64KB,实际应用中一般为10~100bytes,字典顺序排序,rowkey的设计至关重要,会影响region分布,如果rowkey设计不合理还会出现region写热点等一系列问题。 …

Spring Boot编程训练系统:架构设计与实现技巧

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理编程训练系统的相关信息成为必然。开发合适…

刘知远LLM——大模型微调:prompt-learningdelta tuning

文章目录 背景&概览Prompt-learningdelta tuning增量式指定式重参数化式 OpenPrompt工具包 对应视频P41-P57 如何高效使用大模型?涉及到NLP的前沿技术,如prompt-learning&delta tuning。 prompt-learning对学习大模型范式的改变,del…

Spring Boot编程训练系统:性能优化实践

摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了编程训练系统的开发全过程。通过分析编程训练系统管理的不足,创建了一个计算机管理编程训练系统的方案。文章介绍了编程训练系统的系统分析部分&…

电子应用产品设计方案-4:基于物联网和人工智能的温度控制器设计方案

一、概述 本温度控制器旨在提供高精度、智能化、远程可控的温度调节解决方案,适用于各种工业和民用场景。 二、系统组成 1. 传感器模块 - 采用高精度的数字式温度传感器,如 TMP117,能够提供精确到 0.01C 的温度测量。 - 配置多个传感器分布在…

如何在 Ubuntu 24.04 上安装和配置 Fail2ban ?

确保你的 Ubuntu 24.04 服务器的安全是至关重要的,特别是如果它暴露在互联网上。一个常见的威胁是未经授权的访问尝试,特别是通过 SSH。Fail2ban 是一个强大的工具,可以通过自动阻止可疑活动来帮助保护您的服务器。 在本指南中,我…

同三维T610UDP-4K60 4K60 DP或HDMI或手机信号采集卡

1路DP/HDMI/TYPE-C(手机/平板等)视频信号输入1路MIC1路LINE OUT,带1路HDMI环出,USB免驱,分辨率4K60,可采集3路信号中其中1路,按钮切换,可采集带TYPE-C接口的各品牌手机/平板/笔记本电脑等 同三维…

Kafka--关于broker的夺命连环问

目录 1、zk在kafka集群中有何作用 2、简述kafka集群中的Leader选举机制 3、kafka是如何处理数据乱序问题的。 4、kafka中节点如何服役和退役 4.1 服役新节点 1)新节点准备 2)执行负载均衡操作 4.2 退役旧节点 5、Kafka中Leader挂了,…

Web项目版本更新及时通知

背景 单页应用,项目更新时,部分用户会出更新不及时,导致异常的问题。 技术方案 给出版本号,项目每次更新时通知用户,版本已经更新需要刷新页面。 版本号更新方案版本号变更后通知用户哪些用户需要通知?…

Android音视频直播低延迟探究之:WLAN低延迟模式

Android WLAN低延迟模式 Android WLAN低延迟模式是 Android 10 引入的一种功能,允许对延迟敏感的应用将 Wi-Fi 配置为低延迟模式,以减少网络延迟,启动条件如下: Wi-Fi 已启用且设备可以访问互联网。应用已创建并获得 Wi-Fi 锁&a…

Appium配置2024.11.12

百度得知:谷歌从安卓9之后不再提供真机layout inspector查看,仅用于支持ide编写的app调试用 所以最新版android studio的android sdk目录下已经没有了布局查看工具... windows x64操作系统 小米k30 pro手机 安卓手机 Android 12 第一步&#xff1a…