机器学习-梯度下降实验一

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection 
import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score
from mpl_toolkits.mplot3d import Axes3D  # 用于3D图plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号# 1. 读取数据并进行处理
data = pd.read_csv('data.csv')# 提取输入 (X) 和输出 (Y)
X = data['X'].values.reshape(-1, 1)
Y = data['Y'].values# 划分训练集和测试集,70% 训练,30% 测试
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=42)# 为输入 X 添加一列 1 以考虑截距项 (bias)
X_train_b = np.c_[np.ones((X_train.shape[0], 1)), X_train]  # 添加截距项
X_test_b = np.c_[np.ones((X_test.shape[0], 1)), X_test]# 初始化参数 (theta)
theta = np.zeros(2)# 定义超参数
learning_rate = 0.01
n_iterations = 1000# 计算代价函数 (均方误差)def compute_cost(X, Y, theta):m = len(Y)predictions = X.dot(theta)cost = (1 / (2 * m)) * np.sum((predictions - Y) ** 2)return cost# 梯度下降算法def gradient_descent(X, Y, theta, learning_rate, n_iterations):m = len(Y)cost_history = np.zeros(n_iterations)for iteration in range(n_iterations):gradients = (1 / m) * X.T.dot(X.dot(theta) - Y)theta = theta - learning_rate * gradientscost_history[iteration] = compute_cost(X, Y, theta)return theta, cost_history# 交叉验证函数def cross_validation(X, Y, learning_rate, n_iterations, k_folds=5):kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)cv_mse = []for train_index, val_index in kfold.split(X):X_train_fold, X_val_fold = X[train_index], X[val_index]Y_train_fold, Y_val_fold = Y[train_index], Y[val_index]# 为每个 fold 的训练数据添加 biasX_train_fold_b = np.c_[np.ones((X_train_fold.shape[0], 1)), X_train_fold]X_val_fold_b = np.c_[np.ones((X_val_fold.shape[0], 1)), X_val_fold]# 初始化 thetatheta = np.zeros(X_train_fold_b.shape[1])# 使用梯度下降训练模型theta_final, _ = gradient_descent(X_train_fold_b, Y_train_fold, theta, learning_rate, n_iterations)# 对验证集进行预测Y_val_pred = predict(X_val_fold, theta_final)# 计算均方误差mse = mean_squared_error(Y_val_fold, Y_val_pred)cv_mse.append(mse)# 返回交叉验证的平均MSEreturn np.mean(cv_mse)# 预测函数def predict(X, theta):X_b = np.c_[np.ones((X.shape[0], 1)), X]  # 添加截距项return X_b.dot(theta)# 自动调优学习率和迭代次数,并加入交叉验证
best_theta = None
best_mse = float('inf')
best_learning_rate = None
best_iterations = Nonelearning_rates = [0.001, 0.01, 0.02]
iteration_steps = [400, 500, 1000, 2000, 4000]
mse_results = np.zeros((len(learning_rates), len(iteration_steps)))for i, lr in enumerate(learning_rates):for j, iterations in enumerate(iteration_steps):cv_mse = cross_validation(X_train, Y_train, lr, iterations)mse_results[i, j] = cv_mse  # 记录每次的MSEif cv_mse < best_mse:best_mse = cv_msebest_learning_rate = lrbest_iterations = iterationsprint(f"Best MSE after cross-validation: {best_mse}, Best Learning Rate: {best_learning_rate}, Best Iterations: {best_iterations}")# 使用最优学习率和迭代次数重新训练模型
theta_final, cost_history = gradient_descent(X_train_b, Y_train, np.zeros(2), best_learning_rate, best_iterations)# 计算训练集和测试集的拟合程度
Y_train_pred = predict(X_train, theta_final)
Y_test_pred = predict(X_test, theta_final)# 计算均方误差和R2
train_mse = mean_squared_error(Y_train, Y_train_pred)
test_mse = mean_squared_error(Y_test, Y_test_pred)
train_r2 = r2_score(Y_train, Y_train_pred)
test_r2 = r2_score(Y_test, Y_test_pred)print(f"Train MSE: {train_mse}, Train R2: {train_r2}")print(f"Test MSE: {test_mse}, Test R2: {test_r2}")# 1. 可视化训练集和测试集的散点图与拟合直线
plt.figure(figsize=(10, 6))
plt.scatter(X_train, Y_train, color='blue', label='Train Data')
plt.scatter(X_test, Y_test, color='orange', label='Test Data')# 画拟合直线
X_range = np.linspace(min(X), max(X), 100)
Y_pred_line = predict(X_range, theta_final)
plt.plot(X_range, Y_pred_line, color='red', label='Fitted Line')# 画新样本的预测结果# 定义多个新输入数据
X_new_sample = np.array([7.0, 8.5, 6.0, 9.0, 5.5])  # 示例多个新输入# 对新输入进行预测
Y_new_pred = predict(X_new_sample, theta_final)print(Y_new_pred)
plt.scatter(X_new_sample, Y_new_pred, color='green', marker='x', s=100, label='Prediction for X=7.0')plt.title('训练集、测试集与预测结果的拟合曲线')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True)
plt.show()# 2. 可视化损失函数变化
plt.figure(figsize=(10, 6))
plt.plot(range(len(cost_history)), cost_history, color='green', label='Cost Function')
plt.title('损失的变化图')
plt.xlabel('Number of Iterations')
plt.ylabel('Cost (MSE)')
plt.grid(True)
plt.legend()
plt.show()# 3. 可视化最佳参数选择(学习率和迭代次数的搜索过程)X_lr, Y_iter = np.meshgrid(iteration_steps, learning_rates)fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')ax.plot_surface(X_lr, Y_iter, mse_results, cmap='viridis')
ax.set_xlabel('Iterations')
ax.set_ylabel('Learning Rate')
ax.set_zlabel('MSE')
ax.set_title('Learning Rate and Iterations vs. MSE')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1537063.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

LEAN 赋型唯一性(Unique Typing)之 κ 简化 (κ reduction)

在《赋型唯一性的证明过程简介》 提及到&#xff0c;κ 简化 &#xff08;κ reduction&#xff09;概念的引入&#xff0c;是为了证明&#xff0c;在不考虑 证据不区分&#xff08;Proof Irrelevance&#xff09;的情况&#xff0c;表达式具备唯一常态&#xff08;Unique norm…

基于paddleocr的批量图片缩放识别

说明 在进行ocr文字识别的时候&#xff0c;有时候我们需要使用批量测试的功能&#xff0c;但是有些图片会识别失败或者个别根本识别不出来&#xff0c;这时候我们可以通过对原图片进行缩放&#xff0c;提高图像的分辨率&#xff0c;然后再次识别&#xff0c;这样可以大大提高图…

轮转数组 给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数

示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出: [5,6,7,1,2,3,4] 解释: 向右轮转 1 步: [7,1,2,3,4,5,6]向右轮转 2 步: [6,7,1,2,3,4,5] 向右轮转 3 步: [5,6,7,1,2,3,4]示例 2: 输入&#xff1a;nums [-1,-100,3,99], k 2 输出&#xff1a;[3,99,-1,-100] 解释: 向右…

网络安全学习(四)Burpsuite

经过测试&#xff0c;发现BP需要指定的JAVA才能安装。 需要的软件已经放在我的阿里云盘。 &#xff08;一&#xff09;需要下载Java SE 17.0.12(LTS) Java Downloads | Oracle 1.2023版Burp Suite 完美的运行脚本的环境是Java17 2.Java8不支持 看一下是否安装成功&#xff0c…

智慧火灾应急救援:无人机、直升机航拍视角下的火灾应急救援检测数据集代码

智慧火灾应急救援&#xff1a;无人机、直升机航拍视角下的火灾应急救援检测数据集 引言 随着科技的发展&#xff0c;无人机、直升机等飞行器在火灾应急救援中的应用越来越广泛。这些飞行器不仅能快速到达火场&#xff0c;而且可以通过搭载的高清摄像机和其他传感器获取火场的…

编辑器拓展(入门与实践)

学习目标:入门编辑器并实现几个简单的工具 菜单编辑器 MenuItem [MenuItem("编辑器拓展/MenuItem")]static void MenuItem(){Debug.Log("这是编辑器拓展");} } 案例 1&#xff1a;在场景中的 GameObject 设置 1. 设置面板2. 直接创建 GameObject 结构…

jvisualvm工具使用-jvm本地调优(一)

前言&#xff1a; 公司的项目上线后&#xff0c;吞吐量越来越小了&#xff0c;也没有特殊异常抛出&#xff0c;测试环境、预生产又一切正常&#xff0c;反复看了日志&#xff0c;不纠结了&#xff0c;直接把可能影响的因素复制到本地开始jvm调试&#xff0c;随便记录贴个安装教…

linux概述与安装虚拟机

linux 1.Linux 概述 Linux 是一个极具影响力和广泛应用的操作系统。 它起源于芬兰人林纳斯・托瓦兹在大学期间编写的开源内核。Linux 作为一个整体&#xff0c;是免费供用户使用的&#xff0c;具备多用户、多任务、支持多线程的强大特性。 Linux 内核是其核心部分&#xff…

鸿蒙 ArkUI组件二

ArkUI组件&#xff08;续&#xff09; 文本组件 在HarmonyOS中&#xff0c;Text/Span组件是文本控件中的一个关键部分。Text控件可以用来显示文本内容&#xff0c;而Span只能作为Text组件的子组件显示文本内容。 Text/Span组件的用法非常简单和直观。我们可以通过Text组件来显…

上海餐饮数据分析与可视化

数据下载入口&#xff1a;PandasPyecharts | 上海市餐饮数据分析可视化 - Heywhale.com 数据介绍 类别&#xff1a;餐饮类别的名称&#xff08;如烧烤、美食、粤菜等&#xff09;行政区&#xff1a;餐厅所在行政区的名称&#xff08;如浦东新区、闵行区等&#xff09;点评数&a…

【Spring框架精讲】进阶指南:企业级Java应用的核心框架(Spring5)

文章目录 【Spring框架精讲】进阶指南&#xff1a;企业级Java应用的核心框架(Spring5)1.Spring框架快速入门1.1七大核心模块1.1.1 Spring Core1.1.2 Spring-Beans1.1.3 Spring Context1.1.4 Spring-Expression1.1.5 Spring AOP1.1.6 JDBC和DAO模块&#xff08;Spring DAO&#…

C语言 | Leetcode C语言题解之第412题Fizz Buzz

题目&#xff1a; 题解&#xff1a; /*** Note: The returned array must be malloced, assume caller calls free().*/ char ** fizzBuzz(int n, int* returnSize) {/*定义字符串数组*/char **answer (char**)malloc(sizeof(char*)*n);for(int i 1;i<n;i){/*分配单个字符串…

visual prompt tuning和visual instruction tuning

visual prompt tuning&#xff1a;作为一种微调手段&#xff0c;其目的是节省参数量&#xff0c;训练时需要优化的参数量小。 输入&#xff1a;视觉信息image token可学习的prompt token 处理任务&#xff1a;比如常见的分类任务 visual prompt tuning visual instruction tu…

Microsoft 365 Copilot: Wave 2

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

Python3将Excel数据转换为文本文件

文章目录 python3安装使用Python将Excel数据转换为文本文件&#xff1a;逐步指南openpyxl库简介前提条件脚本解析代码详细解析实际应用场景使用示例 结论 python3安装 centos安装python3 Python3基础知识 使用Python将Excel数据转换为文本文件&#xff1a;逐步指南 在数据处理…

闯关leetcode——27. Remove Element

大纲 题目地址内容 解题代码地址 题目 地址 https://leetcode.com/problems/remove-element/description/ 内容 Given an integer array nums and an integer val, remove all occurrences of val in nums in-place. The order of the elements may be changed. Then retur…

Docker 消息队列RabbitMQ 安装延迟消息插件

介绍 RabbitMQ的官方推出了一个插件&#xff0c;原生支持延迟消息功能。该插件的原理是设计了一种支持延迟消息功能的交换机。当消息投递到交换机后可以暂存一定时间&#xff0c;到期后再投递到队列。 查看版本号 docker exec rabbit名字 rabbitmqctl version根据版本下载 插…

Java | Leetcode Java题解之第412题Fizz Buzz

题目&#xff1a; 题解&#xff1a; class Solution {public List<String> fizzBuzz(int n) {List<String> answer new ArrayList<String>();for (int i 1; i < n; i) {StringBuffer sb new StringBuffer();if (i % 3 0) {sb.append("Fizz"…

启动windows更新/停止windows更新,在配置更新中关闭自动更新的方法

在Windows操作系统中&#xff0c;启动或停止Windows更新&#xff0c;以及调整“配置更新”的关闭方法&#xff0c;涉及多种途径&#xff0c;这里将详细阐述几种常用的专业方法。 启动Windows更新 1.通过Windows服务管理器&#xff1a; -打开“运行”对话框&#xff08;…

《小迪安全》学习笔记04

这一块主要讲信息收集——渗透测试第一步&#xff01;&#xff01; 1.首先看有无网站&#xff1a; 存在CDN就用上次说的方法找到真实IP&#xff0c;然后转上↑ 收集四类信息&#xff1a;程序源码&#xff08;CMS&#xff09;等等 2.看有无APP&#xff0c;如涉及到WEB&#xf…