PyTorch的模型定义方法

文章目录

  • 1、简介
  • 2、导包
  • 3、设置属性
  • 4、构建数据集
  • 5、训练函数
    • 5.1、初始准备
    • 5.2、训练过程
    • 5.3、绘制图像
  • 6、运行效果
  • 7、完整代码

🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎人工智能和前端开发。
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

1、简介

前面我们使用手动的方式来构建了一个简单的线性回归模型,如果碰到一些较大的网络设计,手动构建过于繁琐。

手动构建线性回归模型:https://xzl-tech.blog.csdn.net/article/details/140623730

所以,我们需要学会使用 PyTorch 的各个组件来搭建网络。
接下来,我们使用 PyTorch 提供的接口来定义线性回归。

  1. 使用 PyTorch 的 nn.MSELoss() 代替自定义的平方损失函数
  2. 使用 PyTorch 的 data.DataLoader 代替自定义的数据加载器
  3. 使用 PyTorch 的 optim.SGD 代替自定义的优化器
  4. 使用 PyTorch 的 nn.Linear 代替自定义的假设函数

解析如下:

数据集和数据加载器

  • 构建数据集对象 TensorDataset,用于将特征 x 和标签 y 封装为一个数据集。
  • 构建数据加载器 DataLoader,用于按批次加载数据,批次大小为 16,并打乱数据顺序。

构建模型、损失函数和优化器

  • 使用 nn.Linear 构建一个线性模型,输入和输出特征数均为 1。
  • 使用均方误差损失函数 nn.MSELoss
  • 使用随机梯度下降优化器 optim.SGD,学习率为 0.01。

训练过程

  • 外层循环控制训练轮数 epochs
  • 内层循环通过数据加载器 dataloader 按批次加载训练数据。
  • 每个批次中:
    • 将训练数据送入模型,计算预测值 y_pred
    • 计算预测值与真实值之间的损失 loss
    • 梯度清零,防止梯度累积。
    • 反向传播计算梯度。
    • 使用优化器更新模型参数。

我们接下来使用 PyTorch 来构建线性回归

2、导包

image.png

3、设置属性

image.png

4、构建数据集

image.png

5、训练函数

5.1、初始准备

image.png

5.2、训练过程

image.png

5.3、绘制图像

image.png

6、运行效果

image.png
从程序运行结果来看,我们绘制一条拟合的直线,和原始数据的直线基本吻合,说明我们训练的还不错。

7、完整代码

# -*- coding: utf-8 -*-
# @Author: CSDN@逐梦苍穹
# @Time: 2024/7/23 4:08import torch  # 导入 PyTorch 库
from torch.utils.data import TensorDataset  # 导入 TensorDataset 类,用于创建数据集
from torch.utils.data import DataLoader  # 导入 DataLoader 类,用于批量加载数据
import torch.nn as nn  # 导入 torch.nn 模块,用于构建神经网络
import torch.optim as optim  # 导入 torch.optim 模块,用于优化算法
from sklearn.datasets import make_regression  # 导入 make_regression 函数,用于生成回归数据集
import matplotlib.pyplot as plt  # 导入 matplotlib.pyplot 模块,用于绘图# 设置 Matplotlib 的字体和显示属性,用来正常显示中文标签和负号
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei,用于显示中文
plt.rcParams['axes.unicode_minus'] = False  # 允许显示负号# 构建数据集
def create_dataset():# 使用 make_regression 函数生成回归数据集x, y, coef = make_regression(n_samples=100,  # 样本数量为 100n_features=1,  # 特征数量为 1noise=10,  # 噪声为 10coef=True,  # 返回系数bias=14.5,  # 偏置为 14.5random_state=0)  # 随机种子为 0# 将构建的数据转换为张量类型x = torch.tensor(x)y = torch.tensor(y)return x, y, coef  # 返回特征、标签和系数# 定义训练函数
def train():# 构建数据集x, y, coef = create_dataset()# 构建数据集对象, 将特征和标签封装为 TensorDataset 对象dataset = TensorDataset(x, y)# 构建数据加载器dataloader = DataLoader(dataset, batch_size=16, shuffle=True)  # 创建 DataLoader 对象,批次大小为 16,并打乱数据# 构建模型model = nn.Linear(in_features=1, out_features=1)  # 创建线性模型,输入特征数为 1,输出特征数为 1# 构建损失函数criterion = nn.MSELoss()  # 使用均方误差损失函数# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-2)  # 使用随机梯度下降法,学习率为 0.01# 初始化训练参数epochs = 100  # 训练轮数为 100# 训练过程for _ in range(epochs):  # 训练 epochs 轮for train_x, train_y in dataloader:  # 遍历每个批次的数据# 将一个批次的训练数据送入模型y_pred = model(train_x.type(torch.float32))  # 计算模型的预测值# 计算损失值loss = criterion(y_pred, train_y.reshape(-1, 1).type(torch.float32))  # 计算批次损失值# 梯度清零optimizer.zero_grad()  # 清零优化器中的梯度# 自动微分(反向传播)loss.backward()  # 反向传播计算梯度# 更新参数optimizer.step()  # 使用优化器更新模型参数# 绘制拟合直线plt.scatter(x, y)  # 绘制散点图x_vals = torch.linspace(x.min(), x.max(), 1000)  # 生成从 x 的最小值到最大值的等间距点y1 = torch.tensor([v * model.weight + model.bias for v in x_vals])  # 计算训练得到的拟合直线y2 = torch.tensor([v * coef + 14.5 for v in x_vals])  # 计算真实的直线plt.plot(x_vals, y1, label='训练')  # 绘制训练得到的拟合直线plt.plot(x_vals, y2, label='真实')  # 绘制真实直线plt.grid()  # 显示网格plt.legend()  # 显示图例plt.show()  # 显示图形# 主程序入口
if __name__ == '__main__':train()  # 调用 train 函数开始训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1486842.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【图形图像-1】SDF

在图形图像处理中,SDF(Signed Distance Field,带符号的距离场)是一种表示图形轮廓和空间距离的数学结构。它通常用于计算机图形学、文本渲染、碰撞检测和物理模拟等领域。 SDF(Signed Distance Field,带符号…

【数据结构】排序算法——Lesson2

Hi~!这里是奋斗的小羊,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 💥💥个人主页:奋斗的小羊 💥💥所属专栏:C语言 🚀本系列文章为个人学习…

算法力扣刷题记录 五十七【236. 二叉树的最近公共祖先】和【235. 二叉搜索树的最近公共祖先】

前言 公共祖先解决。二叉树和二叉搜索树条件下的最近公共祖先。 二叉树篇继续。 一、【236. 二叉树的最近公共祖先】题目阅读 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q&#xff…

Windows 磁盘分区样式有几种?如何查看电脑分区样式?

在使用 Windows 操作系统的过程中,磁盘分区是一个重要的概念。磁盘分区的方式直接影响到数据存储和系统运行的效率。磁盘分区的时候也有不同的样式,你知道分区类型有哪些吗?不同的分区样式决定了硬盘的分区方式、可支持的最大存储容量以及兼容…

学习笔记:MySQL数据库操作3

1. 创建数据库和表 创建数据库 mydb11_stu 并使用该数据库。创建 student 表,包含字段:学号(主键,唯一),姓名,性别,出生年份,系别,地址。创建 score 表&…

Etsy:以手工制品和复古商品闻名的美国淘宝允许AI艺术品售卖

Etsy是一个美国网络商店平台,以手工艺成品买卖为主要特色,曾被纽约时报拿来和eBay,Amazon比较,被誉为“祖母的地下室收藏”。 Etsy 是一家以手工制品和复古商品闻名的美国网络商店平台在线市场,以手工艺成品买卖为主要…

由“微软蓝屏”事件引发的对网络安全与系统稳定性的思考

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、软件更新流程的漏洞与改进二、强化应急响应机制三、技术创新与应用四、关键行业的特殊应对五、用户意识的提升与数据备份六、全球合作与统一标准总结 前言 …

浅谈断言之XML断言

浅谈断言之XML断言 XML断言是JMeter的一个组件,用于验证请求的响应数据是否符合XML结构。这对于测试返回XML格式数据的Web服务特别有用。 如何添加XML断言? 要在JMeter测试计划中添加XML断言,遵循以下步骤: 打开测试计划&…

The Sandbox:虚拟游戏世界生态系统详解

元宇宙由区块链、软件基础、移动应用、控制台等组成,是一个虚拟空间,结合了增强现实(AR)、虚拟现实(VR)和在线游戏等元素。它强调互操作性,允许用户在不同的虚拟平台之间自由切换。与传统的现实…

病理AI领域的常用开源工具汇总

小罗碎碎念 本期推文主题:病理AI领域的常用开源工具汇总 我们有快一周的时间没见啦,所以,这一期推文带来一些比较有实用价值的资源。 我总结了5个病理AI领域常用的软件,用专用于注释的,也有包含整个处理流程的&#x…

【Linux】UDP 协议

目录 1. UDP 协议2. UDP 协议的特点:3. UDP 协议的格式4. UDP 的缓冲区基于UDP的应用层协议 1. UDP 协议 UDP (User Datagram Protocol) 是一种面向数据报的传输层协议, 是传输层的重要协议之一; UDP协议提供了一种无连接, 不可靠的数据传输服务; 适用于要求源主机以恒定速率…

主控制类,项目小结,实时更新UI

1.用户的信息进行更改,上传请求,服务端进行直接操作数据库,返回请求,客户端根据返回的请求,进行更新界面。 按照我前一篇所说的,写好了主控制类,和第二线程接受服务端的信息,这时候…

【Hot100】LeetCode—416. 分割等和子集

目录 题目1- 思路2- 实现⭐152. 乘积最大子数组——题解思路 3- ACM 实现 题目 原题连接:416. 分割等和子集 1- 思路 理解为背包问题 思路: 能否将均分的子集理解为一个背包,比如对于 [1,5,11,5],判断能否凑齐背包为 11 的容量…

leetcode算法题之接雨水

这是一道很经典的题目,问题如下: 题目地址 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 解法1:动态规划 动态规划的核心就是将问题拆分成若干个子问题求解&#…

2024算法、高性能计算与人工智能国际学术会议(AHPCAI 2024)

2024算法、高性能计算与人工智能国际学术会议(AHPCAI 2024) 2024 International Conference on Algorithms, High Performance Computing and Artificial Intelligence 2024年8月14-16日 | 中国-郑州 2024中国算力大会正在发起“算力中国最佳学术论文…

今天我们聊聊C#的并发和并行

并发和并行是现代编程中的两个重要概念,它们可以帮助开发人员创建高效、响应迅速、高性能的应用程序。在C#中,这些概念尤为重要,因为该语言提供了对多线程和异步编程的强大支持。本文将介绍C#中并发和并行编程的关键概念、优点,并…

Langchain核心模块与实战[7]:专业级Prompt工程调教LLM[输入输出接口、提示词模板与例子选择器的协同工程]

Langchain核心模块与实战[7]:专业级Prompt工程调教LLM[输入输出接口、提示词模板与例子选择器的协同工程] 1. 大模型IO接口 任何语言模型应用的核心元素是…模型的输入和输出。LangChain提供了与任何语言模型进行接口交互的基本组件。 提示 prompts : 将模型输入模板化、动态…

【LeetCode:3096. 得到更多分数的最少关卡数目+ 前缀和】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

百度,有道,谷歌翻译API

API翻译 百度,有道,谷歌API翻译(只针对中英相互翻译),其他语言翻译需要对应from,to的code 百度翻译 package fills.tools.translate; import java.util.ArrayList; import java.util.HashMap; import java.util.Lis…

宠物空气净化器哪款除臭效果好?质量好的养狗空气净化器排名

作为一个宠物家电小博主,炎炎夏日,家中的宠物给你带来的不仅仅是温暖的陪伴,还有那挥之不去的宠物异味。普通空气净化器虽然能够应对一般的空气净化需求,但对于养猫家庭特有的挑战,如宠物毛发、皮屑和异味等&#xff0…