【PyTorch攻略(2/7)】 加载数据集

一、说明

        PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,允许您使用预加载的数据集以及您自己的数据。数据集存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象,以便轻松访问样本。

        PyTorch域库提供了许多示例预加载数据集,例如FashionMNIST,它子类torch.utils.data.Dataset并实现特定于特定数据的函数。可以在此处找到它们并用作原型设计和基准测试模型的示例:

  • 图像数据集
  • 文本数据集
  • 音频数据集

二、加载数据集

        我们将从TorchVision加载FashionMNIST数据集。FashionMNIST是Zalando文章图像的数据集,由60,000个训练示例和10,000个测试示例组成。每个示例包含一个 28x28 灰度图像和一个来自 10 个类之一的相关标签。

  • 每张图片高 28 像素,宽 28 像素,共 784 像素。
  • 这 10 个类告诉它是什么类型的图像。例如,T型短裤/上衣,裤子,套头衫,连衣裙,包,踝靴等。
  • 灰度是介于 0 到 255 之间的值,用于测量黑白图像的强度。强度值从白色增加到黑色。例如,白色为 0,黑色为 255。

        我们使用以下参数加载 FashionMNIST 数据集:

  • 是存储训练/测试数据的路径。
  • 训练指定训练或测试数据集。
  • download = 如果数据在根目录中不可用,则 True 从互联网下载数据。
  • 转换指定特征和标注转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

三、迭代和可视化数据集

我们可以像列表一样手动索引数据集:training_data[index]。我们使用 matplotlib 来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx] # Iterate training datafigure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

四、准备数据以使用数据加载程序进行训练

        数据集检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”方式传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的多处理来加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。要素是输入,标注是输出。我们训练特征并训练模型来预测标签。

        DataLoader 是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。要使用数据加载器,我们需要设置以下参数:

  • 数据是将用于训练模型的训练数据;以及用于评估模型的测试数据。
  • 批大小是每个批中要处理的记录数。
  • 随机播放是按索引随机抽取的数据。
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

五、遍历数据加载器

        我们已将数据集加载到数据加载器中,并可以根据需要循环访问数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含 batch_size = 64 个要素和标注)。由于我们指定了 shuffle = True,因此在我们遍历所有批次后,数据将被洗牌。

# Display image and label
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

        NOrmalization是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素都有一个介于 0 到 255 之间的值,这些值是特征。如果一个像素值为 17,另一个像素值为 197。像素重要性的分布将不均匀,因为较高的像素体积会偏离学习。归一化会更改数据的范围,而不会扭曲其在我们的功能之间的区别。进行此预处理是为了避免:

  • 预测精度降低
  • 模型学习的难度
  • 特征数据范围的不利分布

六、变换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用 transform 对数据执行一些操作,并使其适合训练。

所有 TorchVision 数据集都有两个参数:用于修改特征的转换和用于修改接受包含转换逻辑的可调用对象的标签target_transformtorchvision.transform模块提供了几种开箱即用的常用变换。

FashionMNIST 功能采用 PIL 图像格式,标签为整数。对于训练,我们需要特征作为规范化张量,标签作为独热编码张量。为了进行这些转换,我们使用ToTensorLamda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

七、ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor,并将图像的像素强度值缩放在 [0., 1.] 范围内。

八、Lambda()

        Lambda 应用任何用户定义的 lambda 函数。在这里,我们定义了一个函数来将整数转换为独热编码张量。它首先创建一个大小为 10(我们数据集中的标签数量)的张量并调用 scatter,它在索引上分配一个 value=1,如标签 y 给出的那样。您也可以将torch.nn.functional.one_hot用作其他选项。

target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

下一>> PyTorch 简介 (3/7)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/139848.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++ 共享内存相关的API

C 共享内存相关的API 1.什么是共享内存1.共享内存的概念2.共享内存的原理3.共享内存使用注意点 2.共享内存有关API的操作函数及示例1.新建共享内存-shmget2.连接共享内存到当前的地址空间-shnat3.当前进程分离共享内存shmdt4.控制共享内存-shmctl5.共享内存操作示例 3.共享内存…

Node2Vec实战---《悲惨世界》人物图嵌入

1. pip各个包后导入 import networkx as nx # 图数据挖掘 import numpy as np # 数据分析 import random # 随机数# 数据可视化 import matplotlib.pyplot as plt %matplotlib inline plt.rcParams[font.sans-serif][SimHei] # 用来正常显示中文标签 plt.rcParams[axes.uni…

定制SQLmap和WAF绕过

1. SQLmap tamper 脚本编写 以sqli-lab第26关为例 输入?id1’ --,报错字符型注入 考虑闭合问题,输入?id1’ and 1,但是回显中and和空格消失了,可知and和空格被过滤了 因为and和or被过滤考虑使用双写绕过手段,空格使…

Linux常用命令—find命令大全

文章目录 一、find命令常用功能1、find命令的基本信息如下。2、按照文件名搜索3、按照文件大小搜索4、按照修改时间搜索5、按照权限搜索举例:6、按照所有者和所属组搜索7、按照文件类型搜索8、逻辑运算符 一、find命令常用功能 1、find命令的基本信息如下。 命令名…

java:java.util.MissingResourceException: Cant find bundle for base name解决方式

java:java.util.MissingResourceException: Cant find bundle for base name解决方式 1 前言 代码执行如下: ResourceBundle.getBundle("res.Message",Locale.getDefault(), ReadMyProps.class.getClassLoader());或 ResourceBundle.getBu…

(25)(25.1) 光学流量传感器的测试和设置

文章目录 25.1.1 测试传感器 25.1.2 校准传感器 25.1.3 测距传感器检查 25.1.4 预解锁检查 25.1.5 首次飞行 25.1.6 第二次飞行 25.1.7 正常操作设置 25.1.8 视频示例(Copter-3.4) 25.1.9 空中校准 25.1.1 测试传感器 将传感器连接至自动驾驶仪…

ROS2 从头开始:第 8 部分 - 使用 ROS2 生命周期节点简化机器人软件组件管理

一、说明 欢迎来到我在 ROS2 上的系列的第八部分。对于那些可能不熟悉该系列的人,我已经涵盖了一系列主题,包括 ROS2 简介、如何创建发布者和订阅者、自定义消息和服务创建、组合和执行器以及 DDS 和 QoS 配置。如果您还没有机会查看以前的帖子&#xff…

2023华为杯数学建模D题第三问——区域双碳目标情景设计样例

在第二问建立好预测模型的基础上,如何设计第三问所说的区域双碳路径,以对宏观政策进行指导! 采用STIRPA的基本模型对中国碳达峰时间进行预测,对该模型公式两边取对数得到: 其中:P为人口,A为GDP…

异常记录-VS

1.文件加载失败 无法找到指定路径 Frame GUID: a6c744a8-0e4a-4fc6- 886a-064283054674 Frame mode: VSFM_ MdiChild Error code: 0x80131515 未理会这个提示,可以打开运行项目,只是会跳出这个弹窗。 无法关闭这个异常的窗口。

十四、MySql的用户管理

文章目录 一、用户管理二、用户(一)用户信息(二)创建用户1.语法:2.案例: (三) 删除用户1.语法:2.示例: (四)修改用户密码1.语法&#…

【MT7628AN】IOT | MT7628AN OpenWRT开发与学习

IOT | MT7628AN OpenWRT开发与学习 时间:2023-06-21 文章目录 `IOT` | `MT7628AN` `OpenWRT`[开发与学习](https://blog.csdn.net/I_feige/article/details/132911634?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22132911634…

【C语言】指针的进阶(四)—— 企业笔试题解析

笔试题1: int main() {int a[5] { 1, 2, 3, 4, 5 };int* ptr (int*)(&a 1);printf("%d,%d", *(a 1), *(ptr - 1));return 0; } 【答案】在x86环境下运行 【解析】 &a是取出整个数组的地址,&a就表示整个数组,因此…

vue项目打包优化

首先第一步通过浏览器看首次加载的问题大小,时间跨度等方面入手 1. Coverage观察 Coverage是chrome开发者工具的一个新功能,从字面意思上可以知道它是可以用来检测代码在网站运行时有哪些js和css是已经在运行,而哪些js和css是还没有用到的&a…

WEB使用VUE3实现地图导航跳转

我们在用手机查看网页时可以通过传入经纬度去设置目的地然后跳转到对应的地图导航软件,如果没有下载软件则会跳转到下载界面 注意: 高德地图是一定会跳转到一个新网页然后去询问用户是否需要打开软件百度和腾讯地图是直接调用软件的这个方法有缺陷&…

【c#-Nuget 包“在此源中不可用”】 Nuget package “Not available in this source“

标题c#-Nuget 包“在此源中不可用”…但 VS 仍然知道它吗? (c# - Nuget package “Not available in this source”… but VS still knows about it?) 背景: 今日从公司svn 上拉取很久很久以前的代码,拉取下来200报错,进一步发…

【机器学习】文本多分类

声明:这只是浅显的一个小试验,且借助了AI。使用的是jupyter notebook,所以代码是一块一块,从上往下执行的 知识点:正则删除除数字和字母外的所有字符、高频词云、混淆矩阵 参考:使用python和sklearn的中文文本多分类…

资源分享 | 情绪脑电研究公开数据集

SEED SEED数据集是由上海交大类脑计算与机器智能研究中心(BCMI)开发的。该数据集是基于脑电的情绪分类任务而设计的数据集。该数据集记录了15名被试在观看积极、中性和消极情绪电影片段时的EEG信号,每个视频片段的时间为3-5分钟。每个参与者重复采集三天&#xff0…

多分类中混淆矩阵的TP,TN,FN,FP计算

关于混淆矩阵,各位可以在这里了解:混淆矩阵细致理解_夏天是冰红茶的博客-CSDN博客 上一篇中我们了解了混淆矩阵,并且进行了类定义,那么在这一节中我们将要对其进行扩展,在多分类中,如何去计算TP&#xff0…

IDEA中创建Java Web项目方法1

以下过程使用IntelliJ IDEA 2021.3 一、File-> New -> Project... 1. 项目类型中选择 Java Enterprise 项目 2. Name:填写自己的项目名称 3. Project template:选择项目的模板,Web application。支持JSP和Servlet的项目 4. Applica…

Nginx location 精准匹配URL = /

Location是什么? Location是Nginx中的块级指令(block directive),通过配置Location指令块,可以决定客户端发过来的请求URI如何处理(是映射到本地文件还是转发出去)及被哪个location处理。 匹配模式 分为两种模式&…