【机器学习】随机森林算法

随机森林(Random Forest)是一种集成学习算法,它结合了多个决策树的输出,以提高预测的准确性和稳定性。随机森林被广泛应用于分类和回归任务中,尤其适用于数据特征之间存在非线性关系或噪声的情况。

在本文中,我们将详细讲解随机森林的原理,并用Numpy实现一个基本的回归随机森林。最后,我们将展示如何使用Scikit-Learn实现随机森林。

随机森林的基本原理

随机森林是由 多棵决策树 组成的集成模型,通过以下步骤生成:

  1. 样本随机抽样(Bootstrap Sampling)

    • 从原始数据集中随机抽取若干个样本,生成多个不同的数据集(可以重复抽样)。
    • 对每个数据集生成一棵决策树模型。
  2. 特征随机选择(Random Feature Selection)

    • 在每个节点分裂时,从所有特征中随机选择一部分特征进行分割,选择使得分裂效果最好的特征。
    • 这样可以降低决策树之间的相关性,提升模型泛化能力。
  3. 模型集成

    • 对于分类任务,通过“投票”机制(多数表决)确定最终分类结果。
    • 对于回归任务,通过取多个决策树预测值的平均值得到最终预测结果。

随机森林的优点包括:

  • 减少模型的方差,提高泛化性能。
  • 不容易出现过拟合,尤其在数据量大、噪声多的情况下。

构建回归随机森林

我们将分以下步骤逐步实现一个简单的随机森林回归模型。

数据生成

首先,生成一组模拟数据,以便后续测试模型的效果。

import numpy as np
import matplotlib.pyplot as plt# 生成模拟数据
np.random.seed(0)
X = np.random.rand(100, 1) * 10  # 特征
y = 2 * X.flatten() + np.sin(X.flatten()) * 5 + np.random.randn(100) * 0.5  # 目标值

构建单棵决策树

在随机森林中,我们需要基于Bootstrap采样数据构建多棵决策树。这里我们实现回归树的基本构建方法,使用均方误差(MSE)作为分割标准:

# 均方误差(MSE)计算
def mean_squared_error(y):return np.var(y) * len(y)# 数据集分割
def split_dataset(X, y, feature, threshold):left_mask = X[:, feature] <= thresholdright_mask = ~left_maskreturn X[left_mask], y[left_mask], X[right_mask], y[right_mask]# 查找最佳分割特征和分割点
def best_split(X, y):best_mse = float("inf")best_feature, best_threshold = None, Nonefor feature in range(X.shape[1]):thresholds = np.unique(X[:, feature])for threshold in thresholds:_, y_left, _, y_right = split_dataset(X, y, feature, threshold)if len(y_left) == 0 or len(y_right) == 0:continuemse_split = mean_squared_error(y_left) + mean_squared_error(y_right)if mse_split < best_mse:best_mse = mse_splitbest_feature = featurebest_threshold = thresholdreturn best_feature, best_threshold# 决策树类
class RegressionTree:def __init__(self, max_depth=3, min_samples_split=2):self.max_depth = max_depthself.min_samples_split = min_samples_splitself.tree = Nonedef fit(self, X, y, depth=0):if len(y) < self.min_samples_split or depth >= self.max_depth:return np.mean(y)feature, threshold = best_split(X, y)if feature is None:return np.mean(y)left_X, left_y, right_X, right_y = split_dataset(X, y, feature, threshold)left_node = self.fit(left_X, left_y, depth + 1)right_node = self.fit(right_X, right_y, depth + 1)self.tree = {"feature": feature, "threshold": threshold, "left": left_node, "right": right_node}return self.treedef predict_sample(self, x, tree):if not isinstance(tree, dict):return treeif x[tree["feature"]] <= tree["threshold"]:return self.predict_sample(x, tree["left"])else:return self.predict_sample(x, tree["right"])def predict(self, X):return np.array([self.predict_sample(x, self.tree) for x in X])

构建随机森林模型

基于上面的决策树实现,我们可以通过多次 Bootstrap 采样来构建随机森林的模型。通过组合多棵决策树的预测结果,提升模型的稳定性。

class RandomForestRegressor:def __init__(self, n_estimators=10, max_depth=3, min_samples_split=2):self.n_estimators = n_estimatorsself.max_depth = max_depthself.min_samples_split = min_samples_splitself.trees = []def bootstrap_sample(self, X, y):indices = np.random.choice(len(y), len(y), replace=True)return X[indices], y[indices]def fit(self, X, y):self.trees = []for _ in range(self.n_estimators):X_sample, y_sample = self.bootstrap_sample(X, y)tree = RegressionTree(max_depth=self.max_depth, min_samples_split=self.min_samples_split)tree.fit(X_sample, y_sample)self.trees.append(tree)def predict(self, X):tree_predictions = np.array([tree.predict(X) for tree in self.trees])return np.mean(tree_predictions, axis=0)

训练与预测

接下来,我们用随机森林模型拟合数据,并可视化预测结果:

# 初始化并训练随机森林
forest = RandomForestRegressor(n_estimators=50, max_depth=4, min_samples_split=5)
forest.fit(X, y)# 预测并绘制结果
X_test = np.linspace(0, 10, 100).reshape(-1, 1)
y_pred = forest.predict(X_test)plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred, color="red", label="随机森林预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("随机森林回归预测")
plt.legend()
plt.show()

使用 Scikit-Learn 实现随机森林

Scikit-Learn 提供了一个简单易用的 RandomForestRegressor,用于快速实现和测试随机森林模型。我们可以用它来验证我们的手动实现。

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error# 使用Scikit-Learn的随机森林
regressor = RandomForestRegressor(n_estimators=50, max_depth=4, min_samples_split=5, random_state=0)
regressor.fit(X, y)# 预测并计算MSE
y_pred_sklearn = regressor.predict(X_test)
mse = mean_squared_error(y, regressor.predict(X))
print("均方误差:", mse)# 可视化
plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred_sklearn, color="green", label="Scikit-Learn 随机森林预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("Scikit-Learn 随机森林预测示意图")
plt.legend()
plt.show()

总结

本文详细介绍了随机森林的工作原理,从基本概念到使用 Bootstrap 样本构建决策树的过程,手动实现了回归的随机森林算法,并用 Scikit-Learn 的 RandomForestRegressor 进行对比。随机森林算法的优势在于其高效的集成学习策略,有助于提升模型的泛化能力并减少过拟合风险。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/4593.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Golang--运算符

1、算术运算符 算术运算符&#xff1a; &#xff0c;-&#xff0c;*&#xff0c;/&#xff0c;%&#xff0c;&#xff0c;--&#xff0c;对数值类型的变量进行运算 package mainimport ("fmt" )func main() {//算术运算符// - * / % --//号在golang中表示正号&…

论文阅读:DynamicDet: A Unified Dynamic Architecture for Object Detection

论文地址&#xff1a;[2304.05552] DynamicDet: A Unified Dynamic Architecture for Object Detection 代码地址&#xff1a;GitHub - VDIGPKU/DynamicDet: [CVPR 2023] DynamicDet: A Unified Dynamic Architecture for Object Detection 概要 本文提出了一种名为 DynamicD…

Flutter 正在切换成 Monorepo 和支持 workspaces

其实关于 Monorepo 和 workspaces 相关内容在之前《Dart 3.5 发布&#xff0c;全新 Dart Roadmap Update》 和 《Flutter 之 ftcon24usa 大会&#xff0c;创始人分享 Flutter 十年发展史》 就有简单提到过&#xff0c;而目前来说刚好看到 flaux 这个新进展&#xff0c;所以就再…

expand,None索引,permute【pytorch】

torch.expand 输入必须是一个向量或等价形式&#xff0c;扩展的最后一个维度与输入大小一致 当输入形状为&#xff08;1&#xff0c;1,1,1,1&#xff0c;……&#xff0c;3&#xff09;_4时。 expand的最后一位输入向量的元素个数&#xff08;长度&#xff09;&#xff08;3&…

GEE | 对Landsat 8 影像进行缨帽变换

基于Landsat 8 影像的缨帽变换 var roi ee.FeatureCollection(users/yongweicao11/Dongguan2022); // Landsat 8 的缨帽变换系数矩阵var Landsat8TC ee.Array([[0.3029, 0.2786 , 0.4733, 0.5599, 0.5082, 0.1872],[-0.2941, -0.2435, -0.5424, 0.7276, 0.0713, -0.1608],[0.…

Obsidian的Vim插件设置配置全流程 -- 脱离鼠标拥抱Vim神教

Obsidian的Vim插件设置配置全流程 -- 脱离鼠标拥抱Vim神教 参考文章引言1. vim 及 vimrc 介绍2. 开启 Obsidian 内置的 Vim3. vimrc 插件的获取和安装4. vimrc 插件的设置5. vimrc 配置文件的设置附件 参考文章 vim 常见操作 Obsidian插件安装教程 引言 vim 很好用&#xff…

6.《双指针篇》---⑥和为S的两个数字(中等但简单)(牛客)

题目传送门 方法一&#xff1a;暴力解法。双循环 方法二&#xff1a;双指针&#xff08;推荐&#xff09; 1.定义一个顺序表&#xff0c;定义左右双指针 2.while循环。判断array[left] array[right] 的值。 3.若等于则将这两个值加入数组。并break 4.若大于则right-- 5.若小于…

LeetCode994. 腐烂的橘子(2024秋季每日一题 54)

在给定的 m x n 网格 grid 中&#xff0c;每个单元格可以有以下三个值之一&#xff1a; 值 0 代表空单元格&#xff1b;值 1 代表新鲜橘子&#xff1b;值 2 代表腐烂的橘子。 每分钟&#xff0c;腐烂的橘子 周围 4 个方向上相邻 的新鲜橘子都会腐烂。 返回 直到单元格中没有…

【51蛋骗鸡一个独立按键控制流水灯开关】2022-1-18

缘由一个独立按键控制流水灯开关-编程语言-CSDN问答 #include<reg52.h>//头文件 sbit k1P3^7;// void main() //主函数 {unsigned char sj0, ls0;unsigned int ys0;P00;/*P0255;*/while(1){if(!k1&&!sj){if(!ls){ls1;/*P00;*/}else ls0;while(!k1);}if(…

shodan(五)连接Mongodb数据库Jenkinsorg、net、查看waf命令

声明&#xff1a;学习素材来自b站up【泷羽Sec】&#xff0c;侵删&#xff0c;若阅读过程中有相关方面的不足&#xff0c;还请指正&#xff0c;本文只做相关技术分享,切莫从事违法等相关行为&#xff0c;本人一律不承担一切后果 引言&#xff1a; 1.Shodan 是一个专门用于搜索连…

lvgl白屏问题(LCD长时间白屏)和优化lvgl

开机白屏时间过长 -- 这里我们不考虑是lvgl占的内存太大的问题&#xff0c;这里考虑的是为什么lcd屏幕启动后会有长时间的白屏。 首先我们要了解lvgl的相关操作&#xff0c;主要集中在一个函数中。只有程序执行到了这个函数&#xff0c;lvgl的屏幕才会显现出来 总结来说就是l…

公网ip和弹性公网ip有什么区别?哪个更好

公网ip和弹性公网ip有什么区别&#xff1f;公网IP和弹性公网IP都是用于互联网通信的IP地址&#xff0c;但它们在灵活性、成本和管理方式上有所不同。公网IP是直接分配给设备的静态IP地址&#xff0c;适用于需要固定外部访问的场景&#xff0c;但可能面临安全风险和设置复杂性。…

DevOps-课堂笔记

各种 aaS 类比于计算机网络的 OSI 参考模型&#xff0c;一个软件应用项目需要不同的支撑层&#xff0c;例如从下至上大概需要&#xff1a; 硬件层面的服务器针对硬件做弹性分配的虚拟化机制&#xff0c;例如虚拟机在虚拟化环境内运行的 OS支撑软件应用的中间件&#xff0c;例…

游戏想实习但定位不清的问题

国内的游戏大厂包括腾讯、网易、盛趣游戏、西山居、米哈游、莉莉丝、完美世界、游族、心动、叠纸、三七、TapTap、Tap4fun、字节跳动、哔哩哔哩、funplus、巨人、IGG、沐瞳等。而国外的游戏大厂则有育碧、EA、拳头、supercell、暴雪、R星、卡普空、任天堂、波兰蠢驴等。 一般来…

Dubbo使用Nacos作为注册中心

使用 Nacos 作为注册中心实现自动服务发现 本示例演示 Nacos 作为注册中心实现自动服务发现&#xff0c;示例基于 Spring Boot 应用展开&#xff0c;可在此查看 完整示例代码 1 基本配置 1.1 增加依赖 增加 dubbo、nacos-client 依赖&#xff1a; <dependencies><…

css基础

文章目录 基础 基础 配置网页的cion图标 在网站根目录下放置 favicon.ico 文件&#xff0c;浏览器在加载网页的时候会自动加载的。这个图片只能是 ico 格式&#xff0c;并且只能叫这个名字 如: css项目的根目录下 刷新网站没有生效&#xff0c;需要强制刷新&#xff0c;shif…

Lucene的Directory的详细使用与性能测试(6)

文章目录 第6章 Directory6.1 Directory介绍6.1.1 FSDirectory1&#xff09;SimpleFSDirectory&#xff1a;2&#xff09;NIOFSDirectory&#xff1a;3&#xff09;MMapDirectory&#xff1a;4&#xff09;FSDirectory子类对比 6.2.2 RAMDirectory 6.2 Directory性能测试环境搭…

HTML+javaScript+CSS

文章目录 HTMLjavaScriptCSS属性区块表单层叠样式表选择器常用属性盒子模型相关属性浮动float定位&#xff08;position&#xff09; JS操作节点事件点击事件onclick()聚焦事件、失焦事件鼠标移入移出事件 定时任务延迟定时任务重复定时任务 判断哪个单选框被选中设置按钮失效冒…

Linux系统每日定时备份mysql数据

一、创建存储脚本的文件夹 创建文件夹&#xff0c;我的脚本放在/root/dbback/mysql mkdir ... cd /root/dbback/mysql 二、编写脚本 vi backup_mysql.sh 复制脚本内容 DB_USER"填写用户名" DB_PASSWORD"填写密码" DB_NAME"数据库名称" # …

【计算机网络】零碎知识点(易忘 / 易错)总结回顾

一、计算机网络的发展背景 1、网络的定义 网络是指将多个计算机或设备通过通信线路、传输协议和网络设备连接起来&#xff0c;形成一个相互通信和共享资源的系统。 2、局域网 LAN 相对于广域网 WAN 而言&#xff0c;局域网 LAN 主要是指在相对较小的范围内的计算机互联网络 …