Pytorch深度学习实践(4)使用Pytorch实现线性回归

使用Pytorch实现线性回归

基本步骤:

  • 准备数据集
  • 设计模型
  • 构造损失函数和优化器
  • 模型训练
    • forward计算损失
    • backward计算梯度
    • update更新参数

准备数据集

[ y p r e d ( 1 ) y p r e d ( 2 ) y p r e d ( 3 ) ] = ω [ x ( 1 ) x ( 2 ) x ( 3 ) ] + b \begin {bmatrix}y_{pred}^{(1)} \\ y_{pred}^{(2)} \\ y_{pred}^{(3)} \end{bmatrix} =\omega \begin {bmatrix}x^{(1)} \\ x^{(2)} \\ x^{(3)} \end{bmatrix} + b ypred(1)ypred(2)ypred(3) =ω x(1)x(2)x(3) +b

import torch
## 注意x和y的值必须是矩阵
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

设计模型

在这里插入图片描述

在Pytorch里,重点是构造计算图

在这里使用的是仿射模型,即线性单元
z = w x + b z = wx + b z=wx+b
需要确定的是 w w w b b b 的维度大小,即要通过输入和输出的维度来确定权重的维度

必须注意的是 l o s s loss loss一定要是一个标量

一般而言,会把模型设计成类

class LinearModel(torch.nn.Module):  #继承自Moduledef __init__(self):  #构造函数super(LinearModel, self).__init__()  # 调用负类的构造self.linear = torch.nn.Linear(1, 1)  # 构造Linear对象 包含权重和偏置def forward(self, x):y_pred = self.linear(x)return y_predmodel = LinearModel()  # 实例化LinearModel()对象

torch.nn.Linear(in_features, out_features, bias=True)参数:

  • in_features 输入的每一个样本的维度
  • out_features 输出的每一个样本的维度
  • bias 是否需要添加偏置,默认为True

forward()方法中y_pred = self.linear(x)调用的是了python中的__call__函数。在Pytorch的Module.__call__()中有一个重要的语句就是forward(),也就是说,在这里我们必须写forward()来去覆盖

定义损失函数和优化器

损失函数

损失函数使用MSE

criterion = torch.nn.MSELoss(size_average=False)

参数设置:

  • size_average,是否对损失求平均,默认为True
  • reduce,用来确定是否要把损失求和降维(特征降维)

一般而言,只考虑size_average

优化器

使用梯度下降

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

参数设置:

  • params,传入模型需要优化的权重
  • lr,学习率

模型训练

训练100次,主要是三个步骤

  • 前馈计算
  • 反向传播
  • 梯度更新

注意不要忘记梯度清零

for epoch in range(100):y_pred = model(x_data)  # 前馈计算loss = criterion(y_pred, y_data)  # 计算损失print(epoch, loss)optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 参数更新## 损失数据可视化
plt.plot(np.arange(100), loss_history)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
## 打印训练后的参数
print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())

在这里插入图片描述

模型测试

x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print("y_pred = ", y_test.item())

测试结果如下
在这里插入图片描述

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np## 注意x和y的值必须是矩阵
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
loss_history = []########## 模型的定义 ##########
class LinearModel(torch.nn.Module):  #继承自Moduledef __init__(self):  #构造函数super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)return y_predmodel = LinearModel()  # 实例化Linear()对象########## 定义损失函数和优化器 ##########
## 损失函数
criterion = torch.nn.MSELoss(size_average=False)
## 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(100):y_pred = model(x_data)  # 前馈计算loss = criterion(y_pred, y_data)  # 计算损失print(epoch, loss)loss_history.append(loss.item())optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播optimizer.step()  # 参数更新## 损失数据可视化
plt.plot(np.arange(100), loss_history)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
## 打印训练后的参数
print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())########## 模型测试 ##########
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print("y_pred = ", y_test.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1487482.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【YashanDB知识库】stmt未close,导致YAS-00103 no free block in sql main pool part 0报错分析

问题现象 问题单:YAS-00103 no free block in sql main pool part 0,YAS-00105 out of memory to allocate hash table of size 256 现象:业务处理sql时,报错YAS-00103 no free block in sql main pool part 0 问题风险及影响…

Springboot 开发之 RestTemplate 简介

一、什么是RestTemplate RestTemplate 是Spring框架提供的一个用于应用中调用REST服务的类。它简化了与HTTP服务的通信,统一了RESTFul的标准,并封装了HTTP连接,我们只需要传入URL及其返回值类型即可。RestTemplate的设计原则与许多其他Sprin…

k8s v1.30 完整安装过程及CNI安装过程总结

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G技术研究。 博客内容主要围绕…

25.x86游戏实战-理解发包流程

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 工具下载: 链接:https://pan.baidu.com/s/1rEEJnt85npn7N38Ai0_F2Q?pwd6tw3 提…

视图,存储过程和触发器

目录 视图 创建视图: 视图的使用 查看库中所有的视图 删除视图 视图的作用: 存储过程: 为什么使用存储过程? 什么是存储过程? 存储过程的创建 创建一个最简单的存储过程 使用存储过程 删除存储过程 带参的存储…

智能家居全在手机端进行控制,未来已来!

未来触手可及:智能家居,手机端的全控时代 艾斯视觉的观点是:在不远的将来,家,这个温馨的港湾,将不再只是我们休憩的场所,而是科技与智慧的结晶。想象一下,只需轻触手机屏幕&#xf…

常用的自动化测试工具有哪些?

什么是自动化测试?简单来说,自动化测试就是通过重复执行预定义的动作来执行测试用例的系统来代替人工操作。为了充分利用自动化,必须选择正确的自动化测试工具。 一、自动化测试工具有哪些 1、Selenium WEB自动化测试 Selenium是网页应用中最…

Java给定一些元素随机从中选择一个

文章目录 代码实现java.util.Random类实现随机取数(推荐)java.util.Collections实现(推荐)Java 8 Stream流实现(不推荐) 完整代码参考(含测试数据) 在Java中,要从给定的数据集合中随机选择一个元素,我们很容易想到可以使用 java.…

ARM系列运行异常排查

一、断点指令BKPT BKPT指令产生软件断点中断,可用于程序的调试。它使处理器停止执行正常指令(使处理器中止预取指)而进入相应的调试程序。 BKPT指令的格式为:BKPT 16位的立即数 二、使用BKPT进行软件异常定位 假设异常发生后…

electron 网页TodoList应用打包win桌面软件数据持久化

参考: electron 网页TodoList工具打包成win桌面应用exe https://blog.csdn.net/weixin_42357472/article/details/140648621 electron直接打包exe应用,打开网页上面添加的task在重启后为空,历史没有被保存,需要持久化工具保存之前…

铠侠最新BiCS8 218L NAND键合技术

随着存储技术的不断演进,Hybrid Bonding(混合键合)技术正逐渐成为内存和存储应用领域的重要组成部分。TechInsights最近对KIOXIA/WD BiCS8 218L CBA 1 Tb 3D TLC NAND进行了深入分析,揭示了这项技术如何在提高存储密度、降低功耗和…

在Windows下部署jar包,关闭命令提示符可以后台运行

前言 大多数情况下,都是选用Linux作为服务器部署服务,在Linux中通过以下命令运行 nohup java -jar xxxxx-1.0-SNAPSHOT.jar 但是有时由于其他原因,或本地测试,或云服务器使用Windows server等等,需要在Windows上面运…

[嵌入式Linux]-常见编译框架与软件包组成

嵌入式常见编译框架与软件包组成 1.嵌入式开发准备工作 主芯片资料包括: 主芯片资料 主芯片开发参考手册;主芯片数据手册;主芯片规格书; 硬件参考 主芯片硬件设计参考资料;主芯片配套公板硬件工程; 软件…

学术研讨 | 基于区块链的隐私计算与数据可信流通研讨会顺利召开

近日,由国家区块链技术创新中心组织的“基于区块链的隐私计算与数据可信流通研讨会”顺利召开,会议邀请了来自全国高校和科研院所的相关领域专家,围绕基于区块链与隐私计算技术的应用需求、研究现状、发展趋势、重点研究方向与研究进展等内容…

基于 LlamaIndex 构建自己的 RAG 知识库

创建虚拟环境用于运行 运行 InternLM 的基础环境,命名为 llamaindex conda create -n llamaindex python3.10 查看存在的环境 conda env list 激活刚刚创建的环境 conda activate llamaindex 安装基本库pytorch,torchvision ,torchaudio,pytorch-cuda 并指定通道&…

【React】JSX 实现列表渲染

文章目录 一、基础语法1. 使用 map() 方法2. key 属性的使用 二、常见错误和注意事项1. 忘记使用 key 属性2. key 属性的选择 三、列表渲染的高级用法1. 渲染嵌套列表2. 条件渲染列表项3. 动态生成组件 四、最佳实践 在 React 开发中,列表渲染是一个非常常见的需求。…

Mac装虚拟机占内存吗 Mac用虚拟机装Windows流畅吗

如今,越来越多的Mac用户选择在他们的设备上安装虚拟机来运行不同的操作系统。其中,最常见的是使用虚拟机在Mac上运行Windows。然而,许多人担心在Mac上装虚拟机会占用大量内存,影响电脑系统性能。此外,有些用户还关心在…

Nginx Proxy缓存

Proxy缓存 缓存类型 网页缓存 (公网)CDN数据库缓存 memcache redis网页缓存 nginx-proxy客户端缓存 浏览器缓存 模块 ngx_http_proxy_module 语法 缓存开关 Syntax: proxy_cache zone | off; Default: proxy_cache off; Context: http,…

【JavaEE】Bean的作用域和生命周期

一.Bean的作用域. 1.1 Bean的相关概念. 通过Spring IoC和DI的学习(不清楚的可以看的前面写过的总结,可以快速入门, http://t.csdnimg.cn/K8Xr0),我们知道了Spring是如何帮助我们管理对象的 通过 Controller , Service , Repository , Component , Configuration , Bean 来声明…

开发桌面程序-Electron入门

Electron是什么 来自官网的介绍 Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面应用程序的框架。 嵌入 Chromium 和 Node.js 到 二进制的 Electron 允许您保持一个 JavaScript 代码代码库并创建 在Windows上运行的跨平台应用 macOS和Linux——不需要本地开发 经验。 总…