机器学习(二)——线性回归模型、多分类学习(附核心思想和Python实现源码)

目录

  • 关于
  • 1. 基本形式
  • 2. 线性回归
    • 2.1 单变量线性回归
    • 2.2 多元线性回归
    • 2.2 对数线性回归
  • 3. 对数几率回归
  • 4. 线性判别分析
  • 5. 多分类学习
    • 5.1 拆分策略
  • 6. 类别不平衡问题
  • X 案例代码
    • X.1 源码
    • X.2 数据集(糖尿病数据集)
    • X.3 模型效果


关于

  • 本文是基于西瓜书(第三章)的学习记录。讲解线性模型的重要概念、Python实现代码。
  • 西瓜书电子版:百度网盘分享链接

1. 基本形式

  • 线性模型的核心思想是使用输入属性的线性组合来预测输出。假设我们有一个示例 a = ( a 1 , a 2 , … , a d ) a=(a_1,a_2,\ldots,a_d) a=(a1,a2,,ad) ,其中 d 是属性的数量。线性模型可以表示为:
    f ( a ) = w 1 a 1 + w 2 a 2 + … + w d a d + b f(a) = w_1 a_1 + w_2 a_2 + \ldots + w_d a_d + b f(a)=w1a1+w2a2++wdad+b这里 ( w 1 , w 2 , … , w d ) ( w_1, w_2, \ldots, w_d ) (w1,w2,,wd)是模型的权重, b b b是偏置项。权重决定了每个属性对预测结果的影响程度,而偏置项则允许模型在没有输入时有一个非零的预测值
  • 线性模型形式简单、易于建模,但却蕴涵着机器学习中一些重要的基本思想。许多功能更为强大的非线性模型可在线性模型的基础上通过引入层级结构或高维映射而得。
  • 由于直观表达了各属性在预测中的重要性,因此线性模型有很好的可解释性。

2. 线性回归

  • 线性回归是线性模型中的一种,它的目标是预测一个连续的输出值

2.1 单变量线性回归

  • 在最简单的情况下,我们只有一个输入属性。我们的目标是找到一条直线,使得预测值 f ( z ) = w z + b f(z) = w z + b f(z)=wz+b尽可能接近真实标记 y 。这里,我们使用均方误差(MSE)作为性能度量,并试图最小化它:
    ( w ∗ , b ∗ ) = arg ⁡ min ⁡ w , b ∑ i = 1 m ( y i − ( w a i + b ) ) 2 (w^*, b^*) = \arg\min_{w, b} \sum_{i=1}^m (y_i - (w a_i + b))^2 (w,b)=argw,bmini=1m(yi(wai+b))2
  • 均方误差有非常好的几何意义,它对应了常用的欧几里得距离或简称"欧氏距离"
  • 最小二乘法:基于均方误差最小化来进行模型求解的方法称为“最小二乘法”。在线性回归中,最小二乘法就是试图找到一条直线,使所有样本到直线上的欧氏距离之和最小。
  • 求解过程称为线性回归模型的最小二乘参数估计

2.2 多元线性回归

  • 当输入属性不止一个时,我们使用最小二乘法来估计模型参数。数据集 D 被表示为一个 m*(d+1) 大小的矩阵 X ,其中每行对应一个示例,最后一列恒为1,用于偏置项 b 。我们的目标是最小化均方误差:
    min ⁡ w , b ∑ i = 1 m ( y i − ( w T a i + b ) ) 2 \min_{w, b} \sum_{i=1}^m (y_i - (w^T a_i + b))^2 w,bmini=1m(yi(wTai+b))2
  • 数据集表示的矩阵X的表示:
    在这里插入图片描述

2.2 对数线性回归

  • 模型公式: ln ⁡ y = w T x + b \ln y=\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b lny=wTx+b。在形式上仍是线性回归,但实质上已是在求取输入空间到输出空间的非线性函数映射
  • 实际上是在试图让 e w T x + b e^{\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b} ewTx+b逼近 y y y
  • 示意图:
    在这里插入图片描述

3. 对数几率回归

  • 对数几率回归是用于二分类问题的线性模型,它通过将线性回归模型的预测值转换为0/1值来实现分类
  • 对数几率函数(Sigmoid函数)是实现这一转换的关键: P ( y = 1 ∣ a ) = 1 1 + e − ( w T a + b ) P(y=1|a) = \frac{1}{1 + e^{-(w^T a + b)}} P(y=1∣a)=1+e(wTa+b)1,其图像如下:
    在这里插入图片描述
    其中 z = ( w T a + b ) z = (w^T a + b) z=(wTa+b),即回归模型的预测值,这个函数将任何实数值的预测转换为0和1之间的概率值
  • 实际就是用线性回归模型的预测结果去逼近真实标记的对数几率,因此,其对应的模型称为"对数几率回归"
  • 虽然它的名字是“回归”,但实际却是一种分类学习方法
  • 它不是仅预测出“类别”,而是可得到近似概率预测,这对许多需利用概率辅助决策的任务很有用
  • sigmoid函数是任意阶可导的凸函数,有很好的数学性质,现有的许多数值优化算法都可直接用于求取最优解.

4. 线性判别分析

  • 核心思想:线性判别分析(LDA)是一种经典的线性学习方法,它试图找到一个投影方向,使得同类样本在这个方向上的投影尽可能接近,而异类样本的投影尽可能远离。在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。原理图如下:
    在这里插入图片描述

5. 多分类学习

  • 多分类学习是将线性模型应用于具有多个类别的问题。
  • 多分类学习的基本思路是“拆解法”,即将多分类任务拆为若干个二分类任务求解,为拆出的每个二分类任务训练一个分类器

5.1 拆分策略

  • 一对一(OvO):为每一对类别训练一个分类器,这样N个类别就会产生N*(N-1)/2二分类任务,测试时通过投票机制确定最终类别。
  • 一对其余(OvR):为每个类别训练一个分类器,每次将一个类的样例作为正例、所有其他类的样例作为反例来训练N 个分类器,选择置信度最大的类别标记作为分类结果。
  • 多对多(MvM):每次将多个类别作为正类,其余作为反类。显然,OvO和OvR是 MvM的特例。

6. 类别不平衡问题

  • 定义:类别不平衡是指不同类别的训练样例数目差异很大的情况。这可能会导致模型偏向于多数类,因为模型的预测倾向于预测出现频率更高的类别。

  • 处理这一问题的基本策略包括:

    • 欠采样:减少多数类的样本数量。如EasyEnsemble利用集成学习机制,将多数类划分为若干个集合供不同的学习器使用,每个学习器使用部分集合,虽然每个学习器是欠采样,但是总的来看不会丢失重要信息
    • 过采样:增加少数类的样本数量。如SMOTE算法对少数类进行插值来产生额外的样例。
    • 阈值移动:调整分类阈值以平衡类别。在类别不平衡的情况下,模型学习到的概率分布可能会偏向于多数类.

X 案例代码

X.1 源码

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 1. 加载数据集
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')# 2. 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')# 3. 构建并训练多元线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 4. 预测测试集上的目标变量
y_pred = model.predict(X_test)# 5. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print("多元线性回归模型性能:")
print(f"Mean Squared Error: {mse:.2f}")
print(f"R^2 Score: {r2:.2f}", '\n')# 6. 绘制实际值与预测值的散点图
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.7)
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Actual vs Predicted Values for Diabetes Dataset')
plt.grid(True)
plt.tight_layout()
plt.show()# 可选:查看模型的系数和截距
print("模型参数:")
print(f"Coefficients: {model.coef_}")
print(f"Intercept: {model.intercept_}", '\n')# 可选:将结果保存到DataFrame中以便进一步分析
results = pd.DataFrame({'Actual': y_test,'Predicted': y_pred
})
print("模型预测结果:")
print(results)

X.2 数据集(糖尿病数据集)

  • 糖尿病数据集包含442名患者的10项生理特征,目标是预测一年后疾病水平的定量测量值。这些特征经过了标准化处理,使得每个特征的平均值为零,标准差为1。

  • 概览

    • 样本数量:442个样本
    • 特征数量:10个特征
    • 目标变量:1个目标变量(一年后疾病水平的定量测量值)
  • 特征描述

    1. 年龄 (age):患者年龄(已标准化)
    2. 性别 (sex):患者性别(已标准化)
    3. 体质指数 (bmi):身体质量指数(已标准化)
    4. 血压 (bp):平均动脉压(已标准化)
    5. S1:血清测量值1(已标准化)
    6. S2:血清测量值2(已标准化)
    7. S3:血清测量值3(已标准化)
    8. S4:血清测量值4(已标准化)
    9. S5:血清测量值5(已标准化)
    10. S6:血清测量值6(已标准化)
  • 目标变量

    • 一年后疾病水平的定量测量值:这是模型需要预测的目标变量。
  • 使用

    • 可以使用 sklearn.datasets.load_diabetes() 函数来加载这个数据集,并查看其详细信息。

X.3 模型效果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/3972.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【机器学习】22. 聚类cluster - K-means

聚类cluster - K-means 1. 定义2. 测量数据点之间的相似性3. Centroid and medoid4. Cluster之间距离的测量方式5. 聚类算法的类别6. K-mean7. 如何解决中心初始化带来的影响8. K-means问题:处理空集群9. 离群值的问题10. Bisecting K-means(二分K-means…

【python_pandas_将列表按照某几列进行分组,再求和,按照原列表的字段顺序返回】

说明: 1、按照[“行描述”,”‘公司代码’, ‘科目代码’, ‘预算项目代码’] 进行分组。 2、对“贷方”列进行求和。 3、最后按照之前的表头顺序进行排序,返回结果列表。 #-*- coding:utf-8-*import pandas as pd def consolidate_salary_provisions(l…

【Vue框架】基础语法练习(1)

其实更多知识点已经在Vue.js官网十分清楚了,大家也可以去官网进行更细节的学习 https://cn.vuejs.org/ 说明:目前最新是Vue3版本的,但是Vue2已经深得人心,所以就是可以支持二者合用。它们最大的区别就是Vue3是组合式API&#xf…

公司如何防止员工泄密?十佳措施拒绝泄密,公司防泄密刻不容缓! (2024最强科普)

如何有效防止员工泄露机密? 作为公司的经营者,您是否意识到了商业秘密的重要性?您是否已经知道应该采取什么样的措施才能保护好自己的商业秘密? 员工的泄密行为不仅可能造成重大的经济损失,还会对企业的声誉造成严重…

[大模型]视频生成-Sora简析

参考资料: Sora技术报告https://openai.com/index/video-generation-models-as-world-simulators/4分钟详细揭密!Sora视频生成模型原理https://www.bilibili.com/video/BV1AW421K7Ut 一、概述 相较于Gen-2、Stable Diffusion、Pika等生成模型的前辈&am…

linux中级(防火墙firewalld)

一。firewalld与iptables区别1.firewalld可以动态修改单条规则,不需要像iptables那样,修改规则后必须全部刷新才可生效。firewalld默认动作是拒绝,则每个服务都需要去设置才能放行,而iptables里默认是每个服务是允许,需…

【C/C++】memcpy函数的使用

零.导言 当我们学习了strcpy和strncpy函数后,也许会疑惑整形数组要如何拷贝,而今天我将讲解的memcpy函数便可以拷贝整形数组。 一.memcpy函数的使用 memcpy函数是一种C语言内存函数,可以按字节拷贝任意类型的数组,比如整形数组。 …

软件测试用例设计:从功能测试到边界值分析

功能测试介绍 功能测试是软件测试的一种重要方式,通过对软件的功能进行测试,来验证软件是否满足需求规格说明书中的各项功能要求。例如,对于一个简单的计算器软件,功能测试的用例可能包括加减乘除等基本运算,以及各种特…

[论文阅读]BERT-based Lexical Substitution

BERT-based Lexical Substitution 基于BERT的词汇替换 ACL2019 BERT-based Lexical Substitution - ACL Anthology 以前关于词汇替换的研究倾向于通过从词汇资源(例如 WordNet)中找到目标词的同义词来获得替代候选词,然后根据其上下文对候…

【Java SE 】特殊报错机制 ---> 异常 !

🔥博客主页🔥:【 坊钰_CSDN博客 】 欢迎各位点赞👍评论✍收藏⭐ 目录 1. 异常概念 1.1 算术异常 1.2. 空指针异常 1.3 数组越界异常 2. 异常的分类 2.1 编译时产生的异常 2.2 运行时产生的异常 3. 如何处理异常 3.1 异常…

使用kettle同步数据流程

使用kettle同步数据流程 一.Kettle软件安装(解压即可使用) 1.windows安装解压 pdi-ce-8.2.0.0-342.zip,点Spoon.bat启动kettle 2.Linux安装 把data-integration目录所有文件上传到服务器 二.安装数据库驱动把需要的…

两级运放的电路版图设计

电路版图文件PDK,88出,点击此处获取,24h秒发 PDF文件免费,已绑定 《集成电路版图设计课程》 课程设计(大作业)报告 2023 - 2024 学年第 1 学期 题 目 CMOS运算放大器的电路设计 专 业 …

Windows不支持配置NFS?还有什么注意事项?

我们前面介绍了如果配置Windows Server的NFS共享(Windows Server2012 R2搭建NFS服务器),也介绍了Linux如何配置NFS共享(CentOS 7搭建NFS服务器)。但是,我最近发现一个问题,那就是桌面版的Window…

解锁测试能力密码:直击三问,成就卓越测试

在测试人眼中真的是“万物皆可测”,不管是物体(铅笔,桌子)、终端(手机,电脑)、软件代码、硬件设备等等。那是因为在底层逻辑中,我们搞清楚了其核心本质,总结起来有三个方…

AutoCAD2021

链接: https://pan.baidu.com/s/1GG93ZFRfV_30xTWtDiv3Ew 提取码: dx8i 简介:一键安装,已经破解。支持W7-w10-w11系统64位

伍光和《自然地理学》电子书(含考研真题、课后习题、章节题库、模拟试题)

《自然地理学》(第4版)由伍光和、王乃昂、胡双熙、田连恕、张建明合著,于2018年11月出版。作为普通高等教育“十一五”国家级规划教材,本书不仅适用于高校地球科学各专业的基础课程,还可供环境、生态等有关科研、教学人…

奥数与C++小学四年级(第十七题 弹跳板)

参考程序代码&#xff1a; #include <iostream> bool visited[101] {false}; // 标记1-100是否被访问过int main() {int step 1; // 初始步数int i 2; // 步长visited[1] true; // 标记位置1已访问while (true) {step i; // 跳到下一个位置if (step >…

206面试题(28~46)

206道Java面试题&#xff08;28~46&#xff09; 28.Array和ArrayList有什么区别&#xff1f; 一、基本性质 Array(数组) Array是一种固定大小的数据结构。 用于存储多个相同类型的元素。 创建时需要指定数组的大小&#xff0c;且长度定义完后不能改变。 ArrayList(动态数组)…

“大跳水”的全新奥迪A3,精准狙击年轻人的心

文/王俣祺 导语&#xff1a;随着传统豪华品牌在国内市场的全面崩盘&#xff0c;奥迪再一次坐不住了。这次&#xff0c;奥迪“割肉”的目标瞄准了被称为“年轻人第一台豪车”的奥迪A3&#xff0c;这款车问世以来&#xff0c;就凭借出色的性能与品质收获了一大批年轻粉丝。如今&a…

网站建设公司怎么选?网站制作公司怎么选才不会出错?

寻找适合靠谱的网站设计公司&#xff0c;不要盲目选广告推最多的几家&#xff0c;毕竟要实现自身品牌营销&#xff0c;还是需要多方面考量。以下几个方面可以作为选择的参考&#xff1a; 1. 专业能力如何&#xff1f; 一个公司的专业能力&#xff0c;决定了最后网站设计的成果…