CatBoost模型Python代码——用CatBoost模型实现机器学习

一、CatBoost模型简介

1.1适用范围

CatBoost(Categorical Boosting)是一种基于梯度提升的机器学习算法,特别适用于处理具有类别特征的数据集。它可以用于分类、回归和排序任务,并且在处理具有大量类别特征的数据时表现优异。典型应用包括但不限于:

  • 电子商务中的推荐系统
  • 客户行为分析
  • 财务风险评估
  • 医疗数据分析
1.2原理

CatBoost使用梯度提升决策树(GBDT)作为其核心算法。其主要特点包括:

  1. 处理类别特征:CatBoost原生支持类别特征,并在内部使用目标编码(target encoding)来处理它们,从而减少了类别变量处理的复杂性。
  2. 顺序增强(Ordered Boosting):在构建每棵树时,CatBoost通过引入一种新的顺序提升方法来避免传统梯度提升中的预测偏差问题。
  3. 随机分片:为了进一步减少过拟合,CatBoost在每次树构建时随机分割数据集。
1.3优点
  • 高效处理类别特征:无需复杂的预处理步骤。
  • 减少过拟合:通过顺序增强和随机分片技术。
  • 易于使用:内置了许多默认的优化参数,适合初学者和快速原型开发。
  • 高性能:在许多实际应用中表现优于其他GBDT算法(如XGBoost和LightGBM)。
1.4缺点
  • 模型训练时间较长:尽管有许多优化,训练时间可能比其他简单模型更长。
  • 内存占用较高:在处理大规模数据时,内存需求较大。

二、实现CatBoost模型的Python代码

下面是一个使用CatBoost进行分类任务的完整Python代码示例,包含详细注释。

2.1导入必要的包和测试数据
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns# 加载Titanic数据集
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'
data = pd.read_csv(url)# 查看数据集的列名
print("Columns in the dataset:", data.columns)
2.2简单的数据预处理
# 简单的数据预处理
# 填充缺失值
# data['Age'].fillna(data['Age'].median(), inplace=True)
# data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)# 将Sex和Embarked转换为类别型特征
data['Sex'] = data['Sex'].astype('category')
# data['Pclass'] = data['Pclass'].astype('Pclass')# 选择特征和目标
features = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']
target = 'Survived'X = data[features]
y = data[target]
2.3构建CatBoost模型
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建CatBoost数据池
categorical_features = ['Sex', 'Pclass']
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)# 初始化并训练CatBoost分类器
model = CatBoostClassifier(iterations=1000,learning_rate=0.1,depth=6,loss_function='Logloss',  # 二分类任务使用'Logloss'verbose=100  # 每100次迭代打印一次信息
)# 训练模型
model.fit(train_pool)# 在测试集上进行预测
y_pred = model.predict(test_pool)
y_pred_proba = model.predict_proba(test_pool)[:, 1]
2.4模型评估
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred))

模型评估输出结果如下 :

0:	learn: 0.6538633	total: 159ms	remaining: 2m 39s
100:	learn: 0.2814504	total: 891ms	remaining: 7.93s
200:	learn: 0.2007734	total: 1.68s	remaining: 6.68s
300:	learn: 0.1536222	total: 2.45s	remaining: 5.69s
400:	learn: 0.1220845	total: 3.19s	remaining: 4.77s
500:	learn: 0.0961718	total: 3.95s	remaining: 3.93s
600:	learn: 0.0810769	total: 4.7s	remaining: 3.12s
700:	learn: 0.0694396	total: 5.45s	remaining: 2.33s
800:	learn: 0.0598153	total: 6.2s	remaining: 1.54s
900:	learn: 0.0527771	total: 6.93s	remaining: 761ms
999:	learn: 0.0474017	total: 7.67s	remaining: 0us
Accuracy: 0.8033707865168539precision    recall  f1-score   support0       0.84      0.85      0.84       1111       0.74      0.73      0.74        67accuracy                           0.80       178macro avg       0.79      0.79      0.79       178
weighted avg       0.80      0.80      0.80       178Feature: Pclass, Importance: 16.480181005946406
Feature: Sex, Importance: 24.322199798316337
Feature: Age, Importance: 27.28642174968946
Feature: Siblings/Spouses Aboard, Importance: 5.125530737270014
Feature: Parents/Children Aboard, Importance: 3.006729091175773
Feature: Fare, Importance: 23.77893761760206
2.5可视化特征重要性(可选)
# 可视化特征重要性(可选)
plt.figure(figsize=(10, 6))
plt.barh(X.columns, feature_importances)
plt.xlabel('Feature Importance')
plt.title('CatBoost Feature Importances')
plt.show()

特征重要性输出结果如下:

 2.6绘制混淆矩阵
# 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

绘制混淆矩阵输出结果如下:

2.7绘制ROC曲线
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

绘制ROC曲线输出结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1486132.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

安装好anaconda,打开jupyter notebook,新建 报500错

解决办法: 打开anaconda prompt 输入 jupyter --version 重新进入jupyter notebook: 可以成功进入进行代码编辑

你了解你的GD32 MCU系统主频是多少吗 ?

系统时钟是GD32 MCU的时基,可以理解为系统的心跳,片上所有的外设以及CPU最原始的时钟都来自于系统时钟,因而明确当前系统时钟是多少非常重要,只有明确了系统时钟,才能够实现准确的定时、准确的采样间隔以及准确的通信速…

回溯题目的套路总结

前言 昨天写完了LeeCode的7,8道回溯算法的题目,写一下总结,这类题目的共同特点就是暴力搜索问题,排列组合或者递归,枚举出所有可能的答案,思路很简单,实现起来的套路也很通用,一…

win10安装ElasticSearch7.x和分词插件

说明: 以下内容整理自网络,格式调整优化,更易阅读,希望能对需要的人有所帮助。 一 安装 Java环境 ElasticSearch使用Java开发的,依赖Java环境,安装 ElasticSearch 7.x 之前,需要先安装jdk-8。…

unity 实现图片的放大与缩小(根据鼠标位置拉伸放缩)

1创建UnityHelper.cs using UnityEngine.Events; using UnityEngine.EventSystems;public class UnityHelper {/// <summary>/// 简化向EventTrigger组件添加事件的操作。/// </summary>/// <param name"_eventTrigger">要添加事件监听的UI元素上…

系统架构设计师①:计算机组成与体系结构

系统架构设计师①&#xff1a;计算机组成与体系结构 计算机结构 计算机的组成结构可以概括为以下几个主要部分&#xff1a;中央处理器&#xff08;CPU&#xff09;、存储器&#xff08;包括主存和外存&#xff09;、输入设备、输出设备&#xff0c;以及控制器、运算器、总线和…

下拉菜单过渡

下拉过渡&#xff0c;利用Y轴的transform&#xff1a;scaleY(0) —》transform&#xff1a;scaleY(1) 代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8" /><meta name"viewport" cont…

C2W3.Assignment.Language Models: Auto-Complete.Part1

理论课&#xff1a;C2W3.Auto-complete and Language Models 文章目录 1 Load and Preprocess Data1.1: Load the data1.2 Pre-process the dataExercise 01.Split data into sentencesExercise 02.Tokenize sentencesExercise 03Split into train and test sets Exercise 04H…

2024.7.22 作业

1.将双向链表和循环链表自己实现一遍&#xff0c;至少要实现创建、增、删、改、查、销毁工作 循环链表 looplinklist.h #ifndef LOOPLINKLIST_H #define LOOPLINKLIST_H#include <myhead.h>typedef int datatype;typedef struct Node {union {int len;datatype data;}…

K8S 部署peometheus + grafana 监控

安装说明 如果有下载不下来的docker镜像可以私信我免费下载。 系统版本为 Centos7.9 内核版本为 6.3.5-1.el7 K8S版本为 v1.26.14 动态存储&#xff1a;部署文档 GitHub地址 下载yaml 文件 ## 因为我的K8S 版本比较新&#xff0c;我下载的是当前的最新版本&#xff0c;你的要…

【启明智显分享】甲醛检测仪HMI方案:ESP32-S3方案4.3寸触摸串口屏,RS485、WIFI/蓝牙可选

今年&#xff0c;“串串房”一词频繁引发广大网友关注。“串串房”&#xff0c;也被称为“陷阱房”“贩子房”——炒房客以低价收购旧房子或者毛坯房&#xff0c;用极度节省成本的方式对房子进行装修&#xff0c;之后作为精修房高价租售&#xff0c;因甲醛等有害物质含量极高&a…

自动驾驶---视觉Transformer的应用

1 背景 在过去的几年&#xff0c;随着自动驾驶技术的不断发展&#xff0c;神经网络逐渐进入人们的视野。Transformer的应用也越来越广泛&#xff0c;逐步走向自动驾驶技术的前沿。笔者也在博客《人工智能---什么是Transformer?》中大概介绍了Transformer的一些内容&#xff1a…

昇思MindSpore 应用学习-K近邻算法实现红酒聚类-CSDN

K近邻算法实现红酒聚类-AI代码解析 本实验主要介绍使用MindSpore在部分wine数据集上进行KNN实验。 1、实验目的 了解KNN的基本概念&#xff1b;了解如何使用MindSpore进行KNN实验。 2、K近邻算法原理介绍 K近邻算法&#xff08;K-Nearest-Neighbor, KNN&#xff09;是一种…

传神社区|数据集合集第7期|法律NLP数据集合集

自从ChatGPT等大型语言模型&#xff08;Large Language Model, LLM&#xff09;出现以来&#xff0c;其类通用人工智能&#xff08;AGI&#xff09;能力引发了自然语言处理&#xff08;NLP&#xff09;领域的新一轮研究和应用浪潮。尤其是ChatGLM、LLaMA等普通开发者都能运行的…

类和对象:完结

1.再深构造函数 • 之前我们实现构造函数时&#xff0c;初始化成员变量主要使⽤函数体内赋值&#xff0c;构造函数初始化还有⼀种⽅ 式&#xff0c;就是初始化列表&#xff0c;初始化列表的使⽤⽅式是以⼀个冒号开始&#xff0c;接着是⼀个以逗号分隔的数据成 员列表&#xf…

嵌入式C/C++、FreeRTOS、STM32F407VGT6和TCP:智能家居安防系统的全流程介绍(代码示例)

1. 项目概述 随着物联网技术的快速发展,智能家居安防系统越来越受到人们的重视。本文介绍了一种基于STM32单片机的嵌入式安防中控系统的设计与实现方案。该系统集成了多种传感器,实现了实时监控、报警和远程控制等功能,为用户提供了一个安全、可靠的家居安防解决方案。 1.1 系…

c++ 高精度加法(只支持正整数)

再给大家带来一篇高精度&#xff0c;不过这次是高精度加法&#xff01;话不多说&#xff0c;开整&#xff01; 声明 与之前那篇文章一样&#xff0c;如果看起来费劲可以结合总代码来看 定义 由于加法进位最多进1位&#xff0c;所以我们的结果ans[]的长度定义为两个加数中最…

【Linux】HTTP 协议

目录 1. URL2. HTTP 协议2.1. HTTP 请求2.2. HTTP 响应 1. URL URL 表示着是统一资源定位符(Uniform Resource Locator), 就是 web 地址&#xff0c;俗称“网址”; 每个有效的 URL 可以通过互联网访问唯一的资源, 是互联网上标准资源的地址; URL 的主要由四个部分组成: sche…

如何查看jvm资源占用情况

如何设置jar的内存 java -XX:MetaspaceSize256M -XX:MaxMetaspaceSize256M -XX:AlwaysPreTouch -XX:ReservedCodeCacheSize128m -XX:InitialCodeCacheSize128m -Xss512k -Xmx2g -Xms2g -XX:UseG1GC -XX:G1HeapRegionSize4M -jar your-application.jar以上配置为堆内存4G jar项…

广州邀请媒体宣传(附媒体名单)

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 广州地区 媒体邀约&#xff1a; 记者现场采访&#xff0c;电视台到场报道&#xff0c;展览展会宣传&#xff0c;广交会企业宣传&#xff0c;工厂探班&#xff0c;媒体专访等。 适合广州…