完整网络模型训练(一)

文章目录

    • 一、网络模型的搭建
    • 二、网络模型正确性检验
    • 三、创建网络函数

一、网络模型的搭建

以CIFAR10数据集作为训练例子

准备数据集:

#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

查看数据集的长度:

train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

运行结果:
在这里插入图片描述

利用DataLoader来加载数据集:

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

搭建CIFAR10数据集神经网络:
在这里插入图片描述
卷积层【1】代码解释:
#第一个数字3表示inputs(可以看到图中为3),第二个数字32表示outputs(图中为32)
#第三个数字5为卷积核(图中为5),第四个数字1表示步长(stride)
#第五个数字表示padding,需要计算,计算公式:
在这里插入图片描述

nn.Conv2d(3, 32, 5, 1, 2)

最大池化代码解释:
#数字2表示kernel卷积核

nn.MaxPool2d(2)

读图
卷积层【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述

最大池化【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
卷积层【2】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
以此类推

展平:
在这里插入图片描述
Flatten后它会变成64*4 *4的一个结果

线性输出:
在这里插入图片描述
线性输入是64*4 *4,线性输出是64,故如下代码
nn.LInear(64 *4 *4,64)

继续线性输出
在这里插入图片描述
nn.LInear(64,10)

搭建网络完整代码:

class Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x

二、网络模型正确性检验

if __name__ == '__main__':sen = Sen()input = torch.ones((64, 3, 32, 32))output = sen(input)print(output.shape)

注释:

input = torch.ones((64, 3, 32, 32))

这一行代码的含义是:创建一个大小为 (64, 3, 32, 32) 的全 1 张量,数据类型为 torch.float32。
64:这是批次大小,代表输入有 64 张图片。
3:这是图片的通道数,通常为 RGB 图像的三个通道 (红、绿、蓝)。
32, 32:这是图片的高和宽,表示每张图片的尺寸为 32x32 像素。
torch.ones 函数用于生成一个全 1 的张量,这里的张量形状适合用于输入图像分类或卷积神经网络(CNN)中常见的 CIFAR-10 或类似的 32x32 像素图像数据。

运行结果:
在这里插入图片描述
可以得到成功变成了【64, 10】的结果。

三、创建网络函数

创建网络模型:

sen = Sen()

搭建损失函数:

loss_fn = nn.CrossEntropyLoss()

优化器:

learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

优化器注释:
使用随机梯度下降(SGD)优化器
learning_rate = 1e-2 这里的1e-2代表的是:1 x (10)^(-2) = 1/100 = 0.01

记录训练的次数:

total_train_step = 0

记录测试的次数:

total_test_step = 0

训练的轮数:

epoch= 10

进行循环训练:

for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")

注释:
imgs, targets = data是解包数据,imgs 是输入图像,targets 是目标标签(真实值)
outputs = sen(imgs)将输入图像传入模型 ‘sen’,得到模型的预测输出 outputs
loss = loss_fn(outputs, targets)计算损失值(Loss),loss_fn 是损失函数,它比较outputs的值与targets 是目标标签(真实值)的误差
optimizer.zero_grad()清除优化器中上一次计算的梯度,以免梯度累积
loss.backward()反向传播,计算损失相对于模型参数的梯度
optimizer.step()使用优化器更新模型的参数,以最小化损失
loss.item() 将张量转换为 Python 的数值
loss.item演示:

import torch
a = torch.tensor(5)
print(a)
print(a.item())

运行结果:
在这里插入图片描述
因此可以得到:item的作用是将tensor变成真实数字5

本章节完整代码展示:

import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoaderclass Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)sen = Sen()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10for i in range(epoch):print(f"第{i+1}轮训练开始")for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print(f"训练次数:{total_train_step},Loss:{loss.item()}")

运行结果:
在这里插入图片描述
可以看到训练的损失函数在一直进行修正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1551540.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

养猪场饲料加工机械设备有哪些

养猪场饲料加工机械设备主要包括以下几类:1‌、粉碎机‌:主要用于将原料进行粉碎,以便与其他饲料原料混合均匀。常见的粉碎机有水滴式粉碎机和立式粉碎机两种,用户可以根据原料的特性选择适合的机型。2‌、搅拌机‌:用…

Element UI教程:如何将Radio单选框的圆框改为方框

大家好,今天给大家带来一篇关于Element UI的使用技巧。在项目中,我们经常会用到Radio单选框组件,默认情况下,Radio单选框的样式是圆框。但有时候,为了满足设计需求,我们需要将圆框改为方框,如下…

【新闻转载】Storm-0501:勒索软件攻击扩展到混合云环境

icrosoft发出警告,勒索软件团伙Storm-0501近期调整了攻击策略,目前正将目标瞄准混合云环境,旨在全面破坏受害者的资产。 该威胁行为者自2021年首次露面,起初作为Sabbath勒索软件行动的分支。随后,他们开始分发来自Hive…

华为OD机试 - 积木最远距离(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试真题(Python/JS/C/C)》。 刷的越多,抽中的概率越大,私信哪吒,备注华为OD,加入华为OD刷题交流群,…

MQTT--EMQX入门+MQTTX使用

目录 1、什么是EMQX?1.1 EMQX介绍1.2 EMQX特点1.3 与物联网之间的关系以及主要的产品主要的产品 2、安装启动2.1 基本命令2.2 目录结构 3、MQTTX客户端3.1 连接配置 总结PS: 1、什么是EMQX? 首先你得有MQTT的知识,不认识MQTT的小伙伴可以先看…

如何初步部署自己的服务器,达到生信分析的及格线2(待更新)

参考我的上一篇博客https://blog.csdn.net/weixin_62528784/article/details/142621762?spm1001.2014.3001.5501, 现在我们已经有了一个能够跑一些基础任务的、基本没有配置的服务器了,接下来要做的任务就是: (1)进一…

单片机毕业设计选题基于单片机的炉温度控制系统

** 文章目录 前言概要功能设计设计思路效果图 程序文章目录 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们…

深拷贝浅拷贝 JS代码实现

文章目录 JS数据类型深拷贝 & 浅拷贝赋值和浅拷贝的区别浅拷贝(Shallow Copy)代码实现 深拷贝(Deep Copy)代码实现 Map & WeakMap示例 WeakMap 和垃圾回收weakmap处理循环引用 typepf & instanceof JS数据类型 基本数…

满填充透明背景二维码生成

前几天项目上线的时候发现一个问题:通过Hutool工具包生成的二维码在内容较少时无法填满(Margin 已设置为 0)给定大小的图片。因此导致前端在显示二维码时样式异常。 从图片中我们可以看到,相同大小的图片,留白内容是不一样的。其中上半部分…

dwceqos网络驱动性能优化

文章介绍 本文会介绍优化QNX系统下io-pkt-v6-hc驱动模块cpu loading过高问题,经过优化可以降低约一半的cpu loading. 问题背景 激光雷达通过以太网发送数据到ADAS域控中,测试发现在激光功能激活的情况下,会出现比较明显的网络丢帧现象。 …

平安养老险深圳分公司积极开展“金融教育宣传月”活动,展现金融为民新风尚

2024年9月,平安养老险深圳分公司以“金融为民谱新篇,守护权益防风险”为主题,正式启动2024年“金融教育宣传月”活动,通过多样化开展进乡村、进商圈、进企业等宣传教育活动,将金融消保知识送达广大消费者身边&#xff…

光通信——PON技术

PON网络结构 PON(Passive Optical Network,无源光网络)系统的基本组成包括OLT(Optical Line Terminal,光线路终端)、ODN(Optical Distribution Network,光分配单元)和ON…

数据结构——队列的基本操作

前言 介绍 🍃数据结构专区:数据结构 参考 该部分知识参考于《数据结构(C语言版 第2版)》24~28页 🌈每一个清晨,都是世界对你说的最温柔的早安:ૢ(≧▽≦)و✨ 目录 前言 1、队列的基本概念…

Oracle 闪回版本(闪回表到指定SCN)

1.创建目录 mkdir /u01/app/oracle/flash 2.配置FRA alter system set db_recovery_file_dest_size15G; alter system set db_recovery_file_dest/u01/app/oracle/flash; 3.设置闪回参数--确保可以闪回48h内的数据库 alter system set db_flashback_retention_target2880; 4…

中关村环球时尚产业联盟 东晟时尚产业创新中心成立

2024年9月6日,中关村环球时尚产业联盟与东晟时尚创新科技(北京)有限公司于中关村科技园东城园举行了隆重的战略合作签约仪式。 中关村科技园东城园领导发表了致辞,并表示东城区作为首都北京的核心区域,拥有深厚的历史…

SW - 装配图旋转到一个想要的正视图

文章目录 SW - 装配图旋转到一个想要的正视图概述笔记将装配图旋转到自己想要的视图的方法保存当前视图选择自己保存的视图END SW - 装配图旋转到一个想要的正视图 概述 在弄装配图。 如果按照SW默认的视图,Y方向是反的。 原因在于我画零件图时,方向就…

从“抄袭”到“原创”:5个超实用的论文降重技巧!

AIPaperGPT,论文写作神器~ https://www.aipapergpt.com/ 每当写完一篇论文,松了一口气准备庆祝时,突然想到还有一个名叫“查重”的终极大Boss等着你,瞬间心情从云端跌入谷底。 是不是你? 很多同学在提交之前&#…

fatfs API使用手册

配置 /*---------------------------------------------------------------------------/ / Configurations of FatFs Module /---------------------------------------------------------------------------*/#define FFCONF_DEF 80286 /* Revision ID *//*---------------…

Spring IoC笔记

目录 1.什么是 IoC? 2.IoC类注解(五大注解) 2.1那为什么要这么多类注解? 2.2五大注解是不是可以混用? 2.3程序被spring管理的条件是? 3.bean对象 3.1Bean 命名约定 3.2获取bean对象 4.⽅法注解 B…

业绩由盈转亏,全面冲刺大模型的360值得期待吗?

在中国互联网市场上,360无疑是一家大家家喻户晓的公司,从安全软件起家,360的服务已经延展到了市场的方方面面,就在最近360的财报正式公布,很多人都在问360的财报该怎么看?全面冲刺大模型的360值得我们期待吗…