机器学习与数据挖掘_使用梯度下降法训练线性回归模型

目录

实验内容

实验步骤

1. 导入必要的库

2. 加载数据并绘制散点图

3. 设置模型的超参数

4. 实现梯度下降算法

5. 打印训练后的参数和损失值

6. 绘制损失函数随迭代次数的变化图

7. 绘制线性回归拟合曲线

8. 基于训练好的模型进行新样本预测

实验代码

实验结果

实验总结


实验内容

(1)编写代码实现基于梯度下降的单变量线性回归算法,包括梯度的计算与验证;

(2)绘制数据散点图,以及得到的直线;

(3)绘制梯度下降过程中损失的变化图;

(4)基于训练得到的参数,输入新的样本数据,输出预测值。


实验步骤

1. 导入必要的库

使用 `numpy` 进行科学计算,并使用 `matplotlib` 来生成图形。为了保证图形中的中文正常显示,设置 `matplotlib` 的字体为黑体,并解决负号显示问题。

2. 加载数据并绘制散点图

使用 `numpy` 的 `genfromtxt` 函数从文件中加载数据,数据以逗号作为分隔符。分别提取第一列数据为 `x` 值,第二列数据为 `y` 值,展示数据点的分布情况。使用 `scatter` 函数绘制散点图,并使用 `show` 函数显示图形。

3. 设置模型的超参数

初始化线性回归模型的参数:学习率 `alpha` 设置为 `0.0001`。权重 `w` 和偏置 `b` 初始化为 `0`。设置梯度下降的迭代次数为 `1000`。获取数据样本数量 `m`。

4. 实现梯度下降算法

定义一个列表 `MSE` 用来存储每次迭代的均方误差。在每次迭代中,分别计算损失函数和模型参数的梯度:对每一个样本点,计算当前的预测值和真实值的误差,进而计算平方误差并累积。计算梯度,分别对权重 `w` 和偏置 `b` 进行更新。更新后的参数 `w` 和 `b` 基于学习率和当前梯度值来进行调整。

5. 打印训练后的参数和损失值

在训练结束后,打印出模型的最终参数 `w` 和 `b`。使用最后一次迭代的均方误差来表示最终的损失函数值。

6. 绘制损失函数随迭代次数的变化图

使用 `plot` 函数绘制损失函数随迭代次数变化的曲线,`x` 轴为迭代次数,`y` 轴为损失值。图形展示了梯度下降过程中损失函数值的变化趋势,验证模型的收敛情况。

7. 绘制线性回归拟合曲线

再次绘制原始数据的散点图,并基于训练得到的参数计算每个数据点的预测值。使用 `plot` 函数绘制线性回归拟合的曲线,并用红色标出拟合的直线。

8. 基于训练好的模型进行新样本预测

输入新的样本数据 `new_sample`,基于训练得到的参数 `w` 和 `b` 计算新的 `y` 值。打印出新样本数据及其对应的预测值。


实验代码

# 导入必要的库
import numpy as np  # 导入科学计算库
import matplotlib.pyplot as plt  # 导入绘图库
from matplotlib import rcParams  # 导入设置绘图样式的参数# 设置字体,防止中文乱码
rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体
rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 加载数据并画出散点图
points = np.genfromtxt('data1.txt', delimiter=',')  # 从文件中加载数据,数据以逗号分隔
x = points[:, 0]  # 获取第一列数据作为x
y = points[:, 1]  # 获取第二列数据作为y
plt.scatter(x, y)  # 绘制散点图,展示数据的分布情况
plt.show()# 2. 设置模型的超参数
alpha = 0.0001  # 学习率
w = 0  # 初始化权重w
b = 0  # 初始化偏置b
num_iter = 1000  # 梯度下降的迭代次数
m = len(points)  # 样本数量# 3. 梯度下降算法
MSE = []  # 用于保存每次迭代的均方误差
for iteration in range(num_iter):# 初始化梯度的和sum_grad_w = 0  # 用于累加w的梯度sum_grad_b = 0  # 用于累加b的梯度total_cost = 0  # 每次迭代的总损失初始化为0# 遍历所有数据点,计算偏导数并更新梯度for i in range(m):x_i = points[i, 0]  # 当前数据点的x值y_i = points[i, 1]  # 当前数据点的y值# 计算当前点的预测值pred_y_i = w * x_i + b# 计算损失函数(平方误差)total_cost += (y_i - pred_y_i) ** 2# 计算梯度sum_grad_w += (pred_y_i - y_i) * x_i  # 对w的偏导数sum_grad_b += (pred_y_i - y_i)  # 对b的偏导数# 计算当前迭代的均方误差total_cost /= mMSE.append(total_cost)  # 保存每次迭代的损失值# 计算偏导数的平均值grad_w = 2 / m * sum_grad_wgrad_b = 2 / m * sum_grad_b# 更新w和b,基于学习率和梯度w -= alpha * grad_wb -= alpha * grad_b# 4. 打印训练后的参数和损失值
print("参数w = ", w)
print("参数b = ", b)
# 使用 MSE[-1] 来表示最后一次迭代的损失函数值
print("最后的损失函数 = ", MSE[-1])# 5. 绘制损失函数随迭代次数的变化图
plt.plot(MSE)
plt.xlabel('迭代次数')
plt.ylabel('损失值')
plt.title('梯度下降过程中的损失函数变化')
plt.show()# 6. 画出拟合曲线
plt.scatter(x, y)  # 原始数据的散点图
pred_y = w * x + b  # 基于最终的w和b计算所有数据点的预测值
plt.plot(x, pred_y, color='red')  # 绘制线性回归拟合的直线,颜色为红色
plt.title('线性回归拟合曲线')
plt.show()# 7. 基于训练得到的参数进行新样本预测
new_sample = np.array([5, 10, 15])  # 新的输入数据
predicted_y = w * new_sample + b  # 计算新样本的预测值
print("输入的新样本数据: ", new_sample)
print("预测的y值: ", predicted_y)

实验结果

1. 数据散点图及其线性回归拟合曲线

数据散点图及其线性回归拟合曲线

2. 梯度下降过程中损失函数变化图

梯度下降过程中损失函数变化图

3. 相关参数展示及新样本数据和其预测值

相关参数展示及新样本数据和其预测值


实验总结

本次实验通过使用梯度下降法训练线性回归模型,实现了单变量线性回归的训练与预测。实验中,我们成功编写了基于梯度下降算法的代码,并通过图形展示了数据的分布情况及模型的拟合效果。

在实验过程中,模型的权重参数和偏置参数通过多次迭代逐步更新,梯度下降法有效地减少了损失函数值。最终,模型收敛到了一个较好的参数组合,使得拟合曲线能够较好地反映数据的趋势。此外,通过绘制损失函数的变化图,我们直观地看到了随着迭代次数的增加,损失值不断下降的过程,验证了梯度下降算法的收敛性。

实验结果表明,使用梯度下降法能够有效训练线性回归模型,并且在小数据集上可以获得较为理想的拟合效果。同时,通过该实验,进一步加深了对线性回归和梯度下降算法的理解和掌握。

总体而言,实验达到了预期的目标,完成了线性回归模型的训练、损失函数的可视化及新样本的预测任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/6642.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

web——sqliabs靶场——第一关

今天开始搞这个靶场,从小白开始一点点学习,加油!!!! 1.搭建靶场 注意点:1.php的版本问题,要用老版本 2.小p要先改数据库的密码,否则一直显示链接不上数据库 2.第一道题&#xff0…

Xamarin 实现播放视频 MP4

我的想法是在App启动时播放一段视频,实现方式如下: 准备一个视频: Logo.mp4 添加到资源中:Assets 然后将资源设置为 AndroidAsset 启动时,将资源文件拷贝到程序目录的files下: protected override void On…

4070显卡只要一毛钱?这个双十一太疯狂了

2024年双十一大战正酣,各大商家使尽浑身解数,奇招频出,真是让人看得目瞪口呆。每日口令红包、攒火力值领裂变红包、限时抢免定金红包……还有各类满减和打折优惠活动,玩法千奇百怪,算来算去索性放弃,真是没…

C++优选算法四 前缀和

前缀和算法是一种常用的优化技术,主要用于加速某些涉及连续子数组或子序列求和的问题。 一、定义与原理 定义:前缀和是指数组中某个位置之前(包括该位置)所有元素的和。前缀和算法则是通过提前计算并存储这些前缀和,…

yum安装指定版本Redis

一,yum安装Redis 1,列出可用的redis版本 yum --showduplicates list redis 只有5.0.3.5版本,如果已经满足需求,可以直接安装 2,安装redis yum -y install 如果显示installed, 说明安装成功了 也可以通过…

DAY21|二叉树Part08|LeetCode: 669. 修剪二叉搜索树、108.将有序数组转换为二叉搜索树、538.把二叉搜索树转换为累加树

目录 LeetCode: 669. 修剪二叉搜索树 基本思路 C代码 LeetCode: 108.将有序数组转换为二叉搜索树 基本思路 C代码 LeetCode: 538.把二叉搜索树转换为累加树 基本思路 C代码 LeetCode: 669. 修剪二叉搜索树 力扣代码链接 文字讲解:LeetCode: 669. 修剪二叉搜…

HarmonyOS基础:鸿蒙系统组件导航Navigation

大家好!我是黑臂麒麟(起名原因:一个出生全右臂自带纹身的高质量程序员😏),也是一位6(约2个半坤年)的前端; 学习如像练武功一样,理论和实践要相结合&#xff0…

​Houdini云渲染如何使用?如何让一个镜头使用成百上千台机器渲染,提高渲染效率

​Houdini云渲染如何使用?如何让一个镜头使用成百上千台机器渲染,提高渲染效率呢,最简单的教程来了! 第一步:云渲码6666注册成都渲染101,并且下载渲染101客户端 客户端是上传下载的工具,将文件…

如何使用Varjo直接观看Blender内容

最近,开源的3D建模程序Blender为Varjo提供了出色的OpenXR支持,包括四视图和凹进渲染扩展。但是在Blender中,默认不启用VR场景检查。要开始使用VR场景检查,只需遵循以下步骤: 1. 下载并安装Blender 2.启用Blender VR场景…

linux 安装anaconda3

1.下载 使用repo镜像网址下载对应安装包 右击获取下载地址,使用终端下载 wget https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh2.安装 使用以下命令可直接指定位置 bash Anaconda3-2024.02-1-Linux-x86_64.sh -b -p /home/anaconda3也…

JavaScript。—关于语法基础的理解—

一、程序控制语句 JavaScript 提供了 if 、if else 和 switch 3种条件语句&#xff0c;条件语句也可以嵌套。 &#xff08;一&#xff09;、条件语句 1、单向判断 &#xff1a; if... &#xff08;1&#xff09;概述 < if >元素用于在判断该语句是否满足特定条…

DDD学习笔记

DDD学习笔记 1. 什么是 DDD&#xff1f; 领域驱动设计&#xff08;Domain-Driven Design, DDD&#xff09;是一种复杂软件系统设计的方法&#xff0c;强调以业务领域为核心进行设计与开发。它通过将业务逻辑与代码组织紧密结合&#xff0c;帮助开发团队更好地理解和实现业务需…

c语言简单编程练习8

1、递归函数&#xff1a; 通过调用自身来解决问题的函数&#xff0c;递归也就是传递和回归&#xff1b; 递归函数的两个条件&#xff1a; 1&#xff09;函数调用函数本身 2&#xff09;一定要有结束条件 循环与递归的区别&#xff1a; 每调用一次递归函数&#xff0c;都会…

如何将MySQL彻底卸载干净

目录 背景&#xff1a; MySQL的卸载 步骤1&#xff1a;停止MySQL服务 步骤2&#xff1a;软件的卸载 步骤3&#xff1a;残余文件的清理 步骤4&#xff1a;清理注册表 步骤五:删除环境变量配置 总结&#xff1a; 背景&#xff1a; MySQL卸载不彻底往往会导致重新安装失败…

linux-环境变量

环境变量是系统提供的一组 name value 的变量&#xff0c;不同的变量有不同的用途&#xff0c;通常都具有全局属性 env 查看环境变量 PATH PATH是一个保存着系统指令路径的一个环境变量&#xff0c;系统提供的指令不需要路径&#xff0c;直接就可以使用就是因为指令的路径…

IDEA修改生成jar包名字的两种方法实现

IDEA修改生成jar包名字的两种方法实现 更新时间&#xff1a;2023年08月18日 11:45:36 作者&#xff1a;白白白鲤鱼 本文主要介绍了IDEA修改生成jar包名字的两种方法实现,通过简单的步骤,您可以修改项目名称并在打包时使用新的名称,具有一定的参考价值,感兴趣的可以了解下 …

【Java Web】JSP实现数据传递和保存(中)中文乱码 转发与重定向

文章目录 中文乱码转发与重定向转发重定向区别 升级示例1 中文乱码 JSP 中默认使用的字符编码方式&#xff1a;iso-8859-1&#xff0c;不支持中文。常见的支持中文的编码方式及其收录的字符&#xff1a; gb2312&#xff1a;常用简体汉字gbk&#xff1a;简体和繁体汉字utf-8&a…

ROS话题通信机制理论模型的学习

话题通信是ROS&#xff08;Robot Operating System&#xff0c;机器人操作系统&#xff09;中使用频率最高的一种通信模式&#xff0c;其实现模型主要基于发布/订阅模式。 一、基本概念 话题通信模型中涉及三个主要角色&#xff1a; ROS Master&#xff08;管理者&#xff0…

【Ai教程】Ollma安装 | 0代码本地运行Qwen大模型,保姆级教程来了!

我们平时使用的ChatGPT、kimi、豆包等Ai对话工具&#xff0c;其服务器都是部署在各家公司的机房里&#xff0c;如果我们有一些隐私数据发到对话中&#xff0c;很难保证信息是否安全等问题&#xff0c;如何在保证数据安全的情况下&#xff0c;又可以使用大预言模型&#xff0c;O…

从工作原理上解释为什么MPLS比传统IP方式高效?

多协议标签交换&#xff08;Multiprotocol Label Switching, MPLS&#xff09;是一种用于高速数据包转发的技术。它通过在网络的入口点对数据包进行标签操作&#xff0c;然后在核心网络内部基于这些标签来快速转发数据包&#xff0c;从而提高了数据传输效率。以下是几个方面解释…