0基础学习PyTorch——最小Demo

大纲

  • 环境准备
    • 安装依赖
  • 训练和推理
    • 训练
      • 生成数据
      • 加载数据
        • TensorDataset
        • DataLoader
      • 定义神经网络
      • 定义损失函数和优化器
      • 训练模型
    • 推理
  • 参考代码

PyTorch以其简洁直观的API、动态计算图和强大的社区支持,在学术界和工业界都享有极高的声誉,成为许多深度学习爱好者和专业人士的首选工具。本系列将更多从工程实践角度探索PyTorch的使用,而不是算法公式的讨论。

环境准备

使用《管理Python虚拟环境的脚本》中的脚本初始化一个虚拟环境。

source env.sh init

在这里插入图片描述

然后进入虚拟环境

source env.sh enter

在这里插入图片描述

安装依赖

source env.sh install pyyaml

在这里插入图片描述

source env.sh install torch

这个过程比较漫长,需要下载一个多G文件。
在这里插入图片描述

source env.sh install numpy

在这里插入图片描述

训练和推理

训练就是模型训练。我们可以认为知道系统的输入和输出(目标),猜测系统中的算法。在这里插入图片描述

推理则是使用模型,计算出对应的输出。

在这里插入图片描述
举一个例子,也是我们后面代码的例子。假如我们使用f(x)=2x+1计算一批随机数(输入)得到一批计算结果(目标),然后我们将这些数据交给一个模型训练器,可以得到一个模型。这个模型的计算结果(输出)应该非常近似于f(x)=2x+1。

训练

生成数据

数据生成不是必须的,因为我们从其他地方获取数据。为了让这个例子没有太多依赖项,我们就自己生成数据。
input_data 是100个随机数数组,它是模型训练的“输入”数据;target_data是对input_data使用f(x)=2x+1得到的一个数组,它是模型训练的“目标”数据。即我们需要模型要将“输入”尽量转换成接近的“目标”。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 生成一些随机数据
torch.manual_seed(0)
input_data = torch.randn(100, 1)  # 100个样本,每个样本有2个特征
target_data = 2 * input_data + 1  # 简单的线性关系

加载数据

# 创建数据加载器
dataset = TensorDataset(input_data, target_data)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
TensorDataset

TensorDataset 是一个简单的数据集封装类,用于将多个张量(tensors)包装在一起。它的主要作用是将特征和标签数据对齐,以便于后续的数据加载和处理。

主要功能:

  • 将多个张量封装成一个数据集。
  • 使得数据集可以通过索引访问。
DataLoader

DataLoader 是一个数据加载器类,用于将数据集分批次加载到模型中进行训练或评估。它的主要作用是提供一个迭代器,能够高效地加载数据,并支持多线程并行加载。

主要功能:

  • 将数据集分批次加载。
  • 支持多线程并行加载数据。
  • 支持数据的随机打乱(shuffle)。
  • 提供一个迭代器,方便在训练循环中使用。

定义神经网络

定义神经网络是深度学习模型开发的核心步骤之一。一个良好定义的神经网络可以有效地学习和泛化数据,从而在各种任务中取得优异的表现。
本文不过度讨论神经网络,只是抛砖引玉,让大家知道结构长什么样子。

# 定义一个简单的神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.linear = nn.Linear(1, 1)  # 输入和输出都是1维def forward(self, x):return self.linear(x)

定义损失函数和优化器

# 实例化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

损失函数用于衡量模型预测值与真实值之间的差异。它是模型优化的目标,模型训练的目的是最小化损失函数的值

优化器用于更新模型参数,以最小化损失函数。

训练模型

对于有限的数据,我们可以通过增加训练次数来优化模型。所以下面代码,我们对一个数据集进行了20次训练。

num_epochs = 20
for epoch in range(num_epochs):for inputs, targets in dataloader:optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

后面会单独开一篇文章来将训练过程。

推理

我们再生成一个输入数组,并计算期望的值。

# 生成一组测试输入数据
test_input = torch.randn(10, 1)  # 生成10个随机样本,每个样本1个特征
test_output = 2 * test_input + 1  # 期望的输出

然后用模型算出结果,并进行比较

# 推理
model.eval()
with torch.no_grad():output = model(test_input)for i in range(len(test_input)):print(f'''Test Input: {test_input[i].item()}, Test Output: {output[i].item()}, Actual Output: {test_output[i].item()}, Diff: {output[i].item() - test_output[i].item()}, Loss: {abs(output[i].item() - test_output[i].item()) / abs(test_output[i].item())* 100:.2f}%\n''')

在这里插入图片描述
在这里插入图片描述
我们看到,模型最后推理出的结果和我们的期望值误差在2%以内。

参考代码

https://github.com/f304646673/deeplearning/tree/main/mvp

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1542901.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++入门基础知识80(实例)——实例5【查看 int, float, double 和 char 变量大小】

成长路上不孤单😊😊😊😊😊😊 【14后😊///C爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于C 实例 【查看 int, float, double 和 c…

vue源码分析(九)—— 合并配置

文章目录 前言1.vue cli 创建一个基本的vue2 项目2.将mian.js文件改成如下3. 运行结果及其疑问? 一、使用 new Vue 创建过程的 2 种场景二、margeOption的详细说明1.margeOption的方法地址2.合并策略的具体使用3.defaultStrat 默认策略方法 三:以生命周期…

9.sklearn-K-means算法

文章目录 环境配置(必看)头文件引用K-means算法1.简介2.API3.代码工程4.运行结果5.模型评估6.小结优缺点 环境配置(必看) Anaconda-创建虚拟环境的手把手教程相关环境配置看此篇文章,本专栏深度学习相关的版本和配置&…

idea使用spring initializr快速创建springboot项目

idea使用spring initializr快速创建springboot项目 1.打开idea,新建项目如图,选择好java版本,我这里是17。2.点击next,首先选择springboot版本,我这里选择3.3.4。勾选springweb,它会帮我们下载关于springmv…

【machine learning-14-特征缩放-归一化】

特征缩放是提升线性回归收敛速度的技巧,什么是特征缩放? 又是什么场景下需要特征缩放,有哪些特征缩放的方法呢? 特征值差异 我们还是以之前房间预测为例: 这里面是特征房屋大小 房间数目 与房价的关系 本文为简化…

数据处理与统计分析篇-day03-python数据分析介绍与环境搭建

概述 python优势 Python作为当下最为流行的编程语言之一 可以独立完成数据分析的各种任务 数据分析领域里有海量开源库 机器学习/深度学习领域最热门的编程语言 在爬虫,Web开发等领域均有应用 常用开源库 numpy NumPy(NumericalPython) 是 Python 语言的一…

#面试系列-腾讯后端一面

03.腾讯后端一面 项目相关 面试官可能是 Go 方向的,我面试的是 Java 方向的,所以面试官也没有问我简历上的项目,主要问了实验室中做的项目,哪个项目比较有技术挑战? 面试主要问了计算级网络相关,以及如果让…

通信工程学习:什么是TLS传输层安全协议

TLS:传输层安全协议 TLS(Transport Layer Security)传输层安全协议是一种用于在两个通信应用程序之间提供保密性、数据完整性以及真实性的安全协议。它是SSL(Secure Sockets Layer)协议的后继者,继承并增强…

数据结构与算法——Java实现 8.习题——移除链表元素(值)

祝福你有前路坦途的好运,更祝愿你能保持内心光亮 纵有风雨,依然选择勇敢前行 —— 24.9.22 203. 移除链表元素 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示…

黎巴嫩BP机爆炸事件启示录:我国应加快供应链安全立法

据报道,当地时间9月17日下午,黎巴嫩首都贝鲁特以及黎巴嫩东南部和东北部多地都发生了BP机爆炸事件。当时的统计数据显示,爆炸造成9人死亡,约2800人受伤。9月18日,死亡人数上升到11人,受伤人数超过4000。 目…

计算机毕业设计 基于 Hadoop平台的岗位推荐系统 SpringBoot+Vue 前后端分离 附源码 讲解 文档

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

知乎:从零开始做自动驾驶定位; 注释详解(二)

这个个系统整体分为: 数据预处理 前端里程计 后端优化 回环检测 显示模块。首先来看一下数据预处理节点做的所有事情: 数据预处理节点 根据知乎文章以及代码我们知道: 节点功能输入输出数据预处理1.接收各传感器信息2.传感器数据时间同步 3.点云运动畸变补偿 4.传…

c++类与对象一

C类与对象(一) 面向对象初步认识 在c语言中,编程是面向过程编程,注重求解问题列出过程,然后调用函数求解问题。 在日常生活中。我们经常会遇到面向过程的问题 手洗衣服就是面向过程 而C是基于面向对象的。关注的是对象,把事情…

html实现TAB选项卡切换

<!DOCTYPE html> <html> <head> <title>选项卡示例</title> <style> .tabs { overflow: hidden; /* 防止选项卡溢出容器 */ border: 1px solid #ccc; background-color: #f1f1f1; } .tab-links { margin: 0; padding: 0; l…

DataX-Web项目的Windows环境部署及基本使用

一,datax-web是什么? DataX Web 是一个在 DataX 基础上开发的分布式数据同步工具,它提供了一个简单易用的操作界面,旨在降低用户使用 DataX 的学习成本,缩短任务配置时间,并减少配置过程中的错误。DataX Web 支持多种数据源,包括 RDBMS、Hive、HBase、ClickHouse、Mongo…

yarn : 无法加载文件 C:\Users\Rog\AppData\Roaming\npm\yarn.ps1,因为在此系统上禁止运行脚本

yarn : 无法加载文件 C:\Users\Rog\AppData\Roaming\npm\yarn.ps1&#xff0c;因为在此系统上禁止运行脚本 设置命令行窗口默认以管理员身份运行&#xff0c;在此基础上输入以下代码&#xff0c;应该就好使了&#xff0c;切记&#xff0c;以下代码才是关键&#xff0c;我基本上…

<刷题笔记> 力扣236题——二叉树的公共祖先

236. 二叉树的最近公共祖先 - 力扣&#xff08;LeetCode&#xff09; 题目解释&#xff1a; 我们以这棵树为例&#xff0c;来观察找不同的最近公共祖先有何特点&#xff1a; 思路一&#xff1a; 除了第二种情况&#xff0c;最近公共祖先满足&#xff1a;一个节点在他的左边&am…

犀牛数据爬虫逆向分析

目标网站 aHR0cHM6Ly93d3cueGluaXVkYXRhLmNvbS9pbmR1c3RyeS9uZXdlc3Q/ZnJvbT1kYXRh 一、抓包分析 请求参数和响应数据都有加密 二、逆向分析 1、请求参数 请求参数生成位置 数据解密涉及到一个异步栈 解密后的数据形式 剩下的就是扣取代码了&#xff0c;很简单&#xff0c;…

Class path contains multiple SLF4J bindings.

最近由于要改kafka成datahub&#xff0c;于是在pom文件上引入了 <dependency><groupId>com.aliyun.datahub</groupId><artifactId>aliyun-sdk-datahub</artifactId><version>2.25.1</version> </dependency> 然后让我去测试…

Linux 进程间通信(管道)

目录 一.理解进程间通信 1.进程间通信的意义 2.进程间如何实现通信呢&#xff1f; 二.匿名管道 1.匿名管道的底层原理 引用计数的应用 2.匿名管道代码实现 a.代码的整体框架 b.写接口 c.读接口 d.子进程资源回收 3.匿名管道的官方接口 4.*匿名管道四种情况和五种特…