pytorch线性/非线性回归拟合

一、线性回归

1. 导入依赖库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.autograd import Variable
  • numpy:用来构建数据
  • matplotlib.pyplot: 将构建好的数据可视化
  • torch.nn:包含了torch已经准备好的层,激活函数、全连接层等
  • torch.optim:提供了神经网络的一系列优化算法,如 SGD、Adam 等
  • torch.autograd:用来自动求导,计算梯度。其中Variable用来包装张量,使得张量能够支持自动求导,但在 PyTorch 0.4 及以后,已经被 Tensor 对象取代。

2. 构建数据

        首先确定一个线性函数,例如y_data = 0.1 * x_data + 0.2。然后在这条直线上加一些噪点,最后看神经网络是否能抵抗这些干扰点,拟合出正确的线性函数。

        只要做神经网络相关的数据处理,就一定要把数据转为张量(tensor)类型。然后想要实现梯度下降算法,就要把张量类型再转为Variable类型。

x_data = np.random.rand(100)
noise = np.random.normal(0, 0.01, x_data.shape)  # 构建正态分布噪点
y_data = x_data * 0.1 + 0.2 + noisex_data = x_data.reshape(-1, 1)  # 把原始数据更改形状,自动匹配任意行,1列
y_data = y_data.reshape(-1, 1)x_data = torch.FloatTensor(x_data)  # 把numpy类型转为tensor类型
y_data = torch.FloatTensor(y_data)
inputs = Variable(x_data)  # 变成variable类型才可以自动求导操作
target = Variable(y_data)

 3. 构建神经网络模型

        构建神经网络模型通常遵循一个相对固定的模板。这种模板不仅让代码结构清晰,还能利用 PyTorch 提供的模块化设计,使得网络的定义、训练、推理更加简洁。

        这里我们定义一个一对一的全连接层即可。使用MSE代价函数,SGD优化算法。

class LinearRegression(nn.Module):# 定义网络结构def __init__(self):super(LinearRegression, self).__init__()  # 固定写法,初始化父类self.fc = nn.Linear(1, 1)  # 定义一个全连接层,且一对一# 定义网络计算(前向传播)def forward(self, x):out = self.fc(x)  # 将输入传递给全连接层return outmodel = LinearRegression()  # 定义模型
mse_loss = nn.MSELoss()  # 使用均方差代价函数
optimizer = optim.SGD(model.parameters(), lr=0.1)  # 使用随机梯度下降法优化模型

4. 模型训练

         在模型训练上,几乎也是一个固定套路。之前写的,inputs和target即x_data和y_data的Variable类型。那么当模型(model)获得输入值(inputs),通过前向传播(forward)就会获得一个输出值(out)。然后通过MSE代价函数就能计算出损失(loss),最后经过计算梯度,优化权值,就完成了一轮训练。共训练1000次,期间可以每隔200次看一下损失值。通过输出结果可以看到loss值在一直变小,训练还不错!

for i in range(1001):out = model(inputs)loss = mse_loss(out, target)  # 计算损失optimizer.zero_grad()  # 梯度清0loss.backward()  # 计算梯度optimizer.step()  # 优化权值if i % 200 == 0:print('第{}次,loss值为:{}'.format(i, loss.item()))

        如果我们查看看最后拟合后的权重值(weight)和偏置值(bias),可以发现和我们之前设计好的的 y_data = 0.1 * x_data + 0.2 几乎非常吻合。

for name, param in model.named_parameters():print('name:{}\nparam:{}\n'.format(name, param))

5. 绘图查看结果

         首先利用scatter画出散点图,然后用plot绘出神经网络的拟合结果。

y_pred = model(inputs)
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred.data.numpy(), color='red')
plt.show()

二、非线性回归

         构建非线性回归时,思路和线性回归几乎一致,只需要把数据改为非线性数据,然后神经网络模型增加一个隐藏层即可。    

1. 构建非线性数据 

        首先事先设计一个非线性函数:y_data = x_data²,然后再加入一些噪点干扰神经网络。

x_data = np.linspace(-2, 2, 200)[:, np.newaxis]  # linspace(起始点,终止点,分割点总数),然后增加维度到(200, 1)
noise = np.random.normal(0, 0.2, x_data.shape)
y_data = np.square(x_data) + noise

2. 修改神经网络模型 

         一般情况下,只有隐藏层使用激活函数才可用来拟合非线性数据,如sigmoid、relu、tanh等。这里可以先确定10个隐藏神经元看效果如何。

class NonLinearRegression(nn.Module):# 定义网络结构def __init__(self):super(NonLinearRegression, self).__init__()  # 固定写法,初始化父类self.fc1 = nn.Linear(1, 10)  #   定义隐藏层,10个隐藏神经元self.tanh = nn.Tanh()  # 激活函数self.fc2 = nn.Linear(10, 1)# 定义网络计算(前向传播)def forward(self, x):x = self.fc1(x)x = self.tanh(x)x = self.fc2(x)return x

        如果想要较短时间的训练来获取一个相对较好的结果,可以尝试 Adam 自适应矩阵优化算法。虽然 Adam 算法可以自动调整学习率,但是一般默认初始值是0.001,最后训练情况不理想,所以这里设置为0.05的初始值。而且这个算法容易过拟合,需要正则化 weight_decay 来提高模型的泛化性。

        注意:这里的代价函数不可以修改为交叉熵(CrossEntropyLoss),因为交叉熵大多用于分类任务。

model = NonLinearRegression()
mse_loss = nn.MSELoss()  # 均方差代价函数
optimizer = optim.Adam(model.parameters(), lr=0.05, weight_decay=0.001)  # 设置L2正则化,防止过拟合

3. 查看拟合结果 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1551432.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

2024还在拼多多赚钱的,无不满足这几个条件

拼多多只是我棋盘上的一小步,整个棋局几人看懂了? 如果我说我做拼多多,其实是另有目的,拼多多只是我棋局里的一小步,你们信吗?认真看文章,后面会为大家揭秘! 先来客观公正的回答下…

Queued Synchronous Peripheral Interface (QSPI)

文章目录 1. 介绍2. Feature List3. 概述3.1 QSPI框图3.2 操作模式3.3 三线模式3.4 时钟极性和时钟相位 4. Master模式4.1 状态机4.2 采样点4.3 波特率4.4 通信模式4.4.1 短数据模式4.4.2 长数据模式4.4.3 短连续模式4.4.4 长连续模式4.4.5 单配置多帧模式4.4.6 XXL模式4.4.7 M…

选择国企eHR人事管理系统的时候,应该注意什么?

近年来,中国正步入高速发展的黄金时期,国有企业(国企)在追求效率和管理水平提升方面迈出了重要步伐。为了进一步实现数字化、流程化和科学化管理,越来越多的国企选择引进eHR(电子人力资源管理)系…

【Diffusion分割】MedSegDiff-v2:Diffusion模型进行医学图像分割

MedSegDiff-V2: Diffusion-Based Medical Image Segmentation with Transformer 摘要: 最近的研究揭示了 DPM 在医学图像分析领域的实用性,医学图像分割模型在各种任务中表现出的出色性能就证明了这一点。尽管这些模型最初是以 UNet 架构为基础的&…

opencv实战项目(三十):使用傅里叶变换进行图像边缘检测

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一,什么是傅立叶变换?二,图像处理中的傅立叶变换:三,傅里叶变换进行边缘检测: 一&#xff0c…

13个大V出文需要准确把握的重要因素

推文作为全球最大的社交平台之一,吸引了很多大V(即具有巨大粉丝团的影响力和的账户)的关注。那些大V常常运用推文发布相关各种各样热点的营销推广信息,以吸引更多人的关注参与。推文的发布时间段是V在宣传推广过程中需要准确把握的…

【真实访问】那些选择土木专业的学生,后来怎么样了?

“你会让孩子报土木专业吗?” 7月15日,澎湃新闻在微博上发起线上调研,截至16日12时,8000多人参与了投票,结果显示近7000人选择“不会,天坑专业”。短短几年时间,土木工程专业的报考从“香饽饽”…

CAN总线的错误类型

前言 CAN总线的错误类型主要包括:位错误、填充错误、格式错误、ACK错误和CRC错误。这里一定要做好CAN总线的错误类型、错误帧类型、节点状态之间的区别。 错误类型是帧传输出错的原因类型;错误帧类型(主动错误帧、被动错误帧)是帧…

基于IntraWeb的数据表格的多选实现

基于IntraWeb的数据表格的多选实现 既可以单条操作,也可以多选操作。 delphi源代码。 BS开发Web网站开发,不需要安装服务器,Apache和IIS都不需要,自带企业级服务器。 运行exe服务器就架好了,直接打开手机浏览器或者…

Zombie Slaughter 写实30个僵尸丧尸带动画角色模型

包含30个操纵的僵尸(15个男性和15个女性角色)+动画 所有僵尸都有分离的身体部位,以获得更好的射击/砍杀体验:) PBR材质包含4种纹理(基色、法线、粗糙度、AO),分辨率为4096x4096。 动画包括: -闲置 -步行 - 走回去 - 向右转 - 向左转 -担心 -尖叫 - 走路惹 -快跑 -阿格罗…

人工智能与伦理:如何确保AI应用中的隐私保护

引言 随着人工智能技术的飞速发展,AI已经渗透到我们生活的各个领域,从智能助手到个性化推荐系统,再到医疗诊断和金融服务,人工智能正在为我们带来前所未有的便利。然而,伴随着AI的广泛应用,隐私保护问题日益…

优选驾考系统小程序的设计

管理员账户功能包括:系统首页,个人中心,驾校管理,驾考文章管理,驾照类型管理,报名入口管理,学员报名管理,练车预约管理,考试场地管理 微信端账号功能包括:系统…

加油卡APP系统:省时、优惠、安心!

在汽车加油的刚需下,如何更加优惠的“加油”成为了大众关心的重点,而以优惠为主的加油卡系统也成为了大众的主要选择。 加油卡系统是汽车加油线上的服务系统,拥有全国各地的加油站权限,能够让车主在手机上进行充值,同…

VMware虚拟机连接公网,和WindTerm

一、项目名称 vmware虚拟机连接公网和windterm 二、项目背景 需求1:windows物理机,安装了vmware虚拟机,需要访问公网资源,比如云服务商的yum仓库,国内镜像加速站的容器镜像,http/https资源。 需求2&#xf…

【git】git分支之谜-十分钟给你讲透彻

这里写自定义目录标题 引子分支的直观模型在 git 中,分支是完整的提交记录分支用commit ID存储人们的直觉通常并没有那么错rebase 使用“直观”的分支概念merge也使用“直观”的分支概念github pull request 也使用直观的想法直觉很好,但它也有一些局限性…

前端编程艺术(1)---HTML

目录 1.HTML 2.注释 3.标题标签 4.段落标签 5.换行与水平分隔线 6.文本格式化标签 7.图像标签和属性 8.超链接 8.列表标签 9.表格标签 10.表单标签 11.HTML5 1.HTML HTML(HyperText Markup Language,超文本标记语言)是一种用于创建…

【JAVA开源】基于Vue和SpringBoot的新闻推荐系统

本文项目编号 T 056 ,文末自助获取源码 \color{red}{T056,文末自助获取源码} T056,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

Node-RED系列教程-23node-red获取OPCDAServer数据(DCOM配置)

准备了一个干净的windows 2016虚拟机: administrator wong@123 以管理员身份进入系统: 准备好如下软件: 使用的nodejs版本为: 设置淘宝镜像源: npm config set registry https://registry.npmmirror.com 安装nodered: npm install -g --unsafe-perm node-red@2.2.2

如何组织鼠标的默认的事件

如何组织鼠标的默认的事件 我原先的代码是 dblclick"checkNode(data)"设置了一个双击的事件,我如果双击的话就会导致这个内容被选中。 选中内容的同时会触发浏览器默认的操作,导致出现复制的框这些东西。 解决的方法。加一句。 mousedown.pr…

Power apps:一次提交多项申请

1、添加一个Form,导入sharepoint列表,添加确认,继续,取消按钮 2、在页面的onvisible属性中添加 Set(applynumber,Last(付款申请表).申请编号1); #定义一个申请编号变量,每次申请,就将列表最后一个…