10.解析解方法推导线性回归——不容小觑的线性回归算法

引言

线性回归是许多复杂机器学习模型的基础。作为一种基本的机器学习方法,线性回归提供了清晰的思路和工具,通过理解其推导过程,可以更好地掌握机器学习的基本原理和模型设计。

通过阅读本篇博客,你可以:

1.学会如何用解析解的方法推导线性回归的最优解

2.了解如何判定损失函数是凸函数或非凸

一、解析解的推导

通过上一篇9.深入线性回归推导出MSE——不容小觑的线性回归算法-CSDN博客的讲解,我们已经得到了线性回归的损失函数形式,也明确了目标就是最小化损失函数,那么问题就变成了 \theta 什么时候可以使得损失函数最小。

1.最小二乘形式变化

我们已知损失函数公式为 :

J(\theta ) = \frac{1}{2}\sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})^{2} 

我们将损失函数变化一个形式,变为以下:

\frac{1}{2}(X\theta - y)^{T}(X\theta -y)

其中 X 为变量集,即所有 x_{i} 样本行程的 m 行 n 列的样本矩阵\theta 即我们要求的最优解,它是一个 m 行 1 列的矩阵。这个公式是如何变化而来的呢?

首先原损失函数中的 h_{\theta } 是线性回归模型中的预测函数,h_{\theta }x_{i} 用来表示预测值 \hat{y_{i}} 。所以我们可以得出以下结论:

\hat{y_{i}} = h_{\theta }x_{i} = x_{i}\theta

\hat{y} = h_{\theta }X = X\theta

得到这个结论之后,我们回归到公式本身:

J(\theta ) = \frac{1}{2}\sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})^{2} 

\Rightarrow J(\theta ) = \frac{1}{2} \sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})(h_{\theta }x_{i}-y_{i})

将上述公式代入,又由于矩阵的性质,我们需要将其中一项转置,这里就相当于一个长度为 m 的向量乘以它自己,说白了就是对应位置相乘相加。

所以我们的公式变为:

J(\theta ) = \frac{1}{2}(X\theta - y)^{T}(X\theta - y)

由矩阵运算的基本性质

可继续推出公式:

J(\theta) = \frac{1}{2}((X\theta)^{T} - y^{T} )(X\theta - y)

\Rightarrow J(\theta) = \frac{1}{2} (\theta^{T}X^{T} - y^{T})(X\theta - y)

最终,我们得到:

J(\theta) = \frac{1}{2}(\theta^{T}X^{T}X\theta - \theta^{T}X^{T}y-y^{T}X\theta+y^{T}y)

2.推导出模型的解析解形式

假使我们开着小车,从下图中寻找最优解。为了便于理解,我们假设存在横轴表示 \theta ,存在纵轴表示 loss损失,曲线是 loss function

我们把最小二乘看成是一个函数曲线,最优解一定是驻点中某个极小值(驻点顾名思义就是小车可以停驻的点)。从图中我们可以看出,驻点的特定是梯度全为0(梯度:函数在某点上的切线的斜率)。

所以要求出 \theta 的解析解形式,我们就可以通过把函数的一阶导函数推导出来,再使其的值为0以求出 \theta 。依据以下求导公式:

我们能将公式进行推导:

{J}'(\theta) = \frac{1}{2}\left [ {(\theta^{T}X^{T}X\theta)}' - {(\theta^{T}X^{T}y)}'-{(y^{T}X\theta)}' + {(y^{T}y)}'\right ]

由于 X 和 y 是已知的,\theta 是我们要求的答案。所以和 \theta 没关系的部分在求导时可以忽略不计,继续推导为以下公式:

\Rightarrow {J}'(\theta) = \frac{1}{2}[2X^{T}X\theta-X^{T}y-(y^{T}X)^T]

\Rightarrow {J}'(\theta) = \frac{1}{2}[2X^{T}X\theta-2X^{T}y]

\Rightarrow {J}'(\theta) = X^{T}X\theta - X^{T}y

然后我们设置导函数为0,去进一步解出来驻点对应的 \theta 值为多少:

0 = X^{T}X\theta - X^{T}y

\Rightarrow X^{T}X\theta = X^{T}y

由于矩阵与逆矩阵相乘可以得到单位矩阵,所以我们最终可以求出 \theta 的解析解形式(解析解为方程的解析式,是方程的精确解,能在任意精度下满足方程):

\theta = (X^{T}X)^{-1}X^{T}y

这样,我们有数据集 X ,y 时,就可以将数据代入上面解析解公式,去直接求出对应的 \theta 值了。比如我们可以设想 X 为 m 行 n 列的矩阵,y 为 m 行 1 列的列向量。X^{T} 是 n 行 m 列的,所以 X^{T}X 就是 n 行 n 列的矩阵。又因为矩阵求逆形状不变,再次乘以 X^{T} 后变为 n 行 m 列的矩阵。最后乘以 y,结果  \theta 就是 n 行 1 列的列向量!

二、判断损失函数是否为凸函数

对于求解最优解而言,判断一个损失函数是否为凸函数是极其重要的。如果一个损失函数是凸函数,那么局部最优解即为全局最优解,这是因为在凸函数上没有局部极小值的存在,所有的局部极小值都位于全局最小值处。

如上图所示,左上和右下为非凸函数,左下和右上为凸函数。在非凸函数中,有很多条极值点,我们无法直接得到最优解。对于二次可微的函数,我们可以通过判断黑塞矩阵(hessian matrix)是否为半正定的来进行判断,所以我们要对目标函数在点 x 处的二阶偏导数进行求解。

对于我们的式子来说,就是在导函数的基础上再次对 \theta 进行求偏导,由于 X^{T}y 对 \theta 的导数为0,所以再次求偏导后的答案为 X^{T}X。所谓的正定就是答案的特征值全为正数,而半正定无非就是特征值大于等于0。这里我们对损失函数求二阶导的黑塞矩阵是 X^{T}X,自己和自己做点乘,所以答案一定是半正定的。

在此处我们不用深入去讨论数学推导的证明。在机器学习中,损失函数往往是凸函数,在深度学习中的损失函数往往是非凸函数。并且在实际应用当中,我们并不要求找到全局最优解,只要模型适用。机器学习的特点就是不强调模型 100% 正确,而是有价值的,堪用的

三、代码实战求解线性回归算法模型

经过大片的理论讲解,相信大家已经对线性回归模型的实现有了深刻的认识,接下来我们就要通过代码的形式来实战求解线性回归模型

1.导入需要使用的库

import numpy as np
import matplotlib.pyplot as plt

我们需要使用numpy模块进行矩阵之间的运算,最后用matplotlib模块中的绘图功能绘制 X 与 y 的关系图。

2.定义样本集

# 回归,有监督监督机器学习,X,y
X = 2 * np.random.rand(100,1)
y = 5 * 4 * X + np.random.randn(100,1)

这里的 X 是所有 x_{i} 组成的100行1列的矩阵。y 是真实值,5是偏置(截距),4是 X 的权重,后面的 np.random.randn(100,1) 是100行1列以正态分布形成的误差矩阵

3.实现解析解公式求解模型

# 为了求解W0截距项,我们给X矩阵加上一列全为1的X0
X_b = np.c_[np.ones((100,1)),X]

上述代码是通过 np.c_ 的方式将截距项恒为1的权重拼接到 X 矩阵中。

# 实现解析解的公式来求解θ
θ = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
print(θ)
"""[[5.21509616][3.77011339]]
"""

我们将得到的 \theta 的解析解公式通过代码去实现,np.linalg.inv() 是numpy模块中用来计算逆矩阵的函数,X_b.T 是变量 X_b 的转置,.dot() 是进行点乘运算(对numpy模块的讲解在专栏前面的文章当中7.科学计算模块Numpy(4)ndarray数组的常用操作(二)_ndarray逐元素相加-CSDN博客)。我们通过以上代码就可以表示公式:

\theta = (X^{T}X)^{-1}X^{T}y

输出 \theta 之后我们可以看到,截距项与权重都是相当接近真实情况,但由于误差的存在,我们不可能得到真实值,只能拟合数据得到最优解。

4.使用模型去预测

X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
print(X_new_b)
"""[[1. 0.][1. 2.]]
"""

我们定义一个2行1列的矩阵作为要预测样本的自变量。再加上截距恒为1的项,就会形成一个2行2列的矩阵。

y_predict = X_new_b.dot(θ)
print(y_predict)
"""[[ 5.21509616][12.75532293]]
"""

随后,我们计算预测值,也就是 \hat{y} 。

y_predict = X_new_b.dot(θ)
print(y_predict)
"""[[ 5.21509616][12.75532293]]
"""

只要使用刚刚拼接完的矩阵点乘 \theta 就可以得到预测值了。

5.绘图

plt.plot(X_new, y_predict, 'r-')
plt.plot(X, y, 'b.')
plt.axis([0, 2, 0, 15])
plt.show()

这边我们使用到了matplotlib模块中的绘图功能,成功绘制了下方的坐标图。其中,横轴表示了输入 x 的值,纵轴表示了 y 的值,红线代表了整体的函数,蓝色的点则是真实值的分布情况。我们可以发现,红色的直线尽可能地穿过了蓝色的点,这就是我们一直说的线性回归模型。

总结

这篇博客讲述了模型解析解的推导原理以及代码实现。希望可以对大家起到作用,谢谢。


关注我,内容持续更新(后续内容在作者专栏《从零基础到AI算法工程师》)!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147128.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

对抗攻击的详细解析:原理、方法与挑战

对抗攻击的详细解析:原理、方法与挑战 对抗攻击(Adversarial Attack)是现代机器学习模型,尤其是深度学习模型中的一个关键安全问题。其本质在于,通过对输入数据添加精微的扰动,人类难以察觉这些扰动&#…

Cyber Weekly #25

赛博新闻 1、阿里云Qwen2.5发布!再登开源大模型王座,Qwen-Max性能逼近GPT-4o 阿里云在云栖大会上宣布通义千问发布新一代开源模型Qwen2.5,开源旗舰模型Qwen2.5-72B性能超越Llama3.1-405B,再次登上全球开源大模型的王座&#xff…

【优选算法之二分查找】No.5--- 经典二分查找算法

文章目录 前言一、二分查找模板:1.1 朴素二分查找模板1.2 查找区间左端点模板1.3 查找区间右端点模板 二、二分查找示例:2.1 ⼆分查找2.2 在排序数组中查找元素的第⼀个和最后⼀个位置2.3 搜索插⼊位置2.4 x 的平⽅根2.5 ⼭脉数组的峰顶索引2.6 寻找峰值…

鸿蒙开发(NEXT/API 12)【跨设备互通NDK开发】协同服务

跨设备互通提供跨设备的相机、扫描、图库访问能力,平板或2in1设备可以调用手机的相机、扫描、图库等功能。 说明 本章节以拍照为例展开介绍,扫描、图库功能的使用与拍照类似。 用户在平板或2in1设备上使用富文本类编辑应用(如:…

深入理解 flex-grow、flex-shrink、flex-basis

目录 1. Flex布局 2. flex-grow 计算方式 3. flex-shrink 计算公式 4. flex-basis 5. 应用场景 6. 总结 1. Flex布局 Flex 是 Flexible Box 的缩写,意为"弹性布局",用来为盒状模型提供最大的灵活性 flex属性是flex-grow, flex-shrink 和…

vscode安装ESLint与Vetur插件后自动修复代码不生效

vscode安装ESLint与Vetur插件后自动修复代码不生效 1、安装ESLint 和 Vuter 2、运行结果 2.1、代码保存时代码中的分号;能被检测出来,但是不会自动修复 2.2、手动运行ESLint 修复命令(在终端中执行 npx eslint . --fix)可以修复问题 3、解决办法 在.vscode目录下setti…

Spring Boot 3.x 配置 Spring Doc以及导入postman带图详解

一、pom.xml配置 <!-- API⽂档⽣成&#xff0c;基于swagger3 --><dependency><groupId>org.springdoc</groupId><artifactId>springdoc-openapi-starter-webmvc-ui</artifactId><version>2.0.2</version></dependency>…

STL-set/multiset关联式容器

目录 一、常见接口 1.0 迭代器 1.1 构造函数 1.2 增删查 1.3 查找和统计 二、multiset 2.1 构造 2.2 查找 2.3 删除 2.4 统计 关联式容器是⽤来存储数据的&#xff0c;与序列式容器不同的是&#xff0c;关联式容器逻辑结构通常是⾮线性结构&#xff0c;两个位置有紧密…

JSP分页功能实现案例:从基础到应用的全面解析

想要实现基于jsp的分页功能&#xff1a; 需要从数据库中获取数据&#xff0c;并在前端页面中分页展示 基于JDBC访问MySQL数据库&#xff0c;获取数据基于JSP处理数据并展示 本质上是JSP的一种开发模式&#xff08;即JSPJavaBean&#xff09; 第一步&#xff1a;创建JavaWeb项目…

gitlab 的CI/CD (二)

前言 上文完成了gitlab-runner的基础配置及将gitlab的制品上传至软件包库&#xff08;产品库&#xff09;的脚本编写&#xff1b; 本文实现gitlab的ci/cd对远程服务器的操作&#xff1b; 介绍 要让Gitlab Runner部署到远程机器&#xff0c;远程机器必须信任gitlab runner账…

C++标准库容器类——string类

引言 在c中&#xff0c;string类的引用极大地简化了字符串的操作和管理&#xff0c;相比 C 风格字符串&#xff08;char*或cahr[]&#xff09;&#xff0c;std::string 提供了更高效和更安全的字符串操作。接下来让我们一起来深入学习string类吧&#xff01; 1.string 的构造…

一种WLAN用户综合认证系统及其方法(本人专利号 201110408124.X)

一种WLAN用户综合认证系统及其方法(本人专利号 201110408124.X&#xff09; 本发明公开了一种WLAN用户综合认证系统及其方法&#xff0c;涉及移动通信技术领域。本系统包括WLAN终端与AP子系统和外部认证中心&#xff1b;设置有认证协议分析引擎单元和用户综合控制单元&#xff…

c/c++内存管理(详解) + new与delete的用法及底层

1:c/c内存分布情况 1.1:c/c内存的分布图 1.2:每个区域的用途及不同类型变量存储在那个区 1.3:例题讲解 2:c动态内存管理方式(new delete) 2.1:new的语法 2.2:delete的语法 3:operator new函数与operator delete函数 4:new与delete的实现原理 5:定位new表达式初识 6:mallo…

python+selenium实现自动联网认证,并实现断网重连

pythonselenium实现自动联网认证&#xff0c;并实现断网重连 echo off python “E:\autoD\auto_login.py” 要使自动登录脚本在系统重启后自动运行&#xff0c;你可以使用Windows的任务计划程序来设置。以下是详细的步骤&#xff1a; 1. 保存脚本 首先&#xff0c;将你的Py…

【高分系列卫星简介——高分二号卫星(GF-2)】

高分二号卫星&#xff08;GF-2&#xff09; 高分二号&#xff08;GF-2&#xff09;卫星是中国自主研制的首颗空间分辨率优于1米的民用光学遥感卫星&#xff0c;具有亚米级空间分辨率、高定位精度和快速姿态机动能力等特点&#xff0c;达到了国际先进水平。以下是对高分二号卫星…

对Spring-AI系列源码的讲解

前言 今天&#xff0c;我们将开启对Spring-AI系列源码的讲解。请大家不急不躁&#xff0c;我会逐步深入&#xff0c;每次专注于一个知识点&#xff0c;以防让人感到困惑。 首先&#xff0c;源码的讨论自然离不开自动装配。有人可能会问&#xff0c;之前已经讲解过这个内容了&…

【JavaSE】八种基本数据类型及包装类

数据类型字节数位数值范围包装类默认值整型byte18-128&#xff0c;127Byte0short216&#xff0c;Short0int432&#xff0c;Integer0long864&#xff0c;Long0L浮点型float432Float0.0fdouble864Double0.0d布尔型boolean18true falseBooleanfalse字符型char2160&#xff0c;Char…

C++编程语言:基础设施:异常处理(Bjarne Stroustrup)

第 13 章 异常处理(Exception Handling) 目录 13.1 错误处理(Error Handling) 13.1.1 异常(Exceptions) 13.1.2 传统错误处理(Traditional Error Handling) 13.1.3 探索(Muddling Through) 13.1.4 异常的替代观点(Alternative Views of Exceptions) 13.1.4.1 异步…

DAY78服务攻防-数据库安全RedisCouchDBH2database未授权访问CVE 漏洞

知识点&#xff1a; 1、数据库-Redis-未授权RCE&CVE 2、数据库-Couchdb-未授权RCE&CVE 3、数据库-H2database-未授权RCE&CVE 前置知识 1、复现环境&#xff1a;Vulfocus(官方在线的无法使用&#xff0c;需要自己本地搭建) 官方手册&#xff1a;https://fofapr…

老牛码看JAVA行业现状

一、坏消息深化与反思&#xff1a; 1、技术瓶颈与框架局限&#xff1a;尽管低代码平台崭露头角&#xff0c;为开发效率带来新气象&#xff0c;但其全面普及尚需时日&#xff0c;Java技术栈的进化似乎陷入了暂时的停滞。开发者们渴望突破&#xff0c;却发现传统框架与模式已难以…