动手学深度学习(pytorch土堆)-06损失函数与反向传播、模型训练、GPU训练

在这里插入图片描述

模型保存与读取

完整模型训练套路

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *train_data=torchvision.datasets.CIFAR10("data_nn",train=True,transform=torchvision.transforms.ToTensor(),download=True )
test_data=torchvision.datasets.CIFAR10("data_nn",train=False,transform=torchvision.transforms.ToTensor(),download=True )
train_data_size=len(train_data)
test_data_size=len(test_data)
print(test_data_size,train_data_size)train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)
#搭建神经网络XKK=xkkk()
#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learning_rate=0.01
optimizer=torch.optim.SGD(XKK.parameters(),lr=learning_rate)
#设置训练网络的一些参数
total_test_step=0
#训练的轮数
epoch=10
#添加tensorboard
writer=SummaryWriter("logs_train")
for i in range(epoch):print("----第{}轮训练开始----".format(i+1))for data in train_dataloader:imgs,targets=dataoutputs=XKK(imgs)loss=loss_fn(outputs,targets)#优化器调优,优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_test_step=total_test_step+1if total_test_step%100==0:print("训练次数:{},loss:{}".format(total_test_step,loss))writer.add_scalar("train_loss",loss.item(),total_test_step)#测试步骤开始total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,targets=dataoutputs=XKK(imgs)loss=loss_fn(outputs,targets)total_test_loss=total_test_loss+lossprint("整体测试集上的Loss:{}".format(total_test_loss))writer.add_scalar("test_loss",total_test_loss,total_test_step)total_test_step=total_test_step+1
writer.close()

在这里插入图片描述

使用GPU训练

import timeimport torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# from model import *train_data=torchvision.datasets.CIFAR10("data_nn",train=True,transform=torchvision.transforms.ToTensor(),download=True )
test_data=torchvision.datasets.CIFAR10("data_nn",train=False,transform=torchvision.transforms.ToTensor(),download=True )
train_data_size=len(train_data)
test_data_size=len(test_data)
print(test_data_size,train_data_size)train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)
#搭建神经网络class xkkk(torch.nn.Module):def __init__(self):super(xkkk,self).__init__()self.model1=torch.nn.Sequential(Conv2d(3, 32, 5,1, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 32, 5, 1,padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 64, 5,stride=1, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(in_features=64*4*4, out_features=64),  # 1024=64*4*4,Linear(64, 10))def forward(self, x):output=self.model1(x)return output
XKK=xkkk()
XKK=XKK.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
#优化器
learning_rate=0.01
optimizer=torch.optim.SGD(XKK.parameters(),lr=learning_rate)
#设置训练网络的一些参数
total_test_step=0
#训练的轮数
epoch=10
#添加tensorboard
writer=SummaryWriter("logs_train")
start_time=time.time()
for i in range(epoch):print("----第{}轮训练开始----".format(i+1))for data in train_dataloader:imgs,targets=dataimgs=imgs.cuda()targets=targets.cuda()outputs=XKK(imgs)loss=loss_fn(outputs,targets)#优化器调优,优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_test_step=total_test_step+1if total_test_step%100==0:end_time=time.time()print(end_time-start_time)print("训练次数:{},loss:{}".format(total_test_step,loss))writer.add_scalar("train_loss",loss.item(),total_test_step)#测试步骤开始total_test_loss=0with torch.no_grad():for data in test_dataloader:imgs,targets=dataimgs = imgs.cuda()targets = targets.cuda()outputs=XKK(imgs)loss=loss_fn(outputs,targets)total_test_loss=total_test_loss+loss.item()print("整体测试集上的Loss:{}".format(total_test_loss))writer.add_scalar("test_loss",total_test_loss,total_test_step)total_test_step=total_test_step+1torch.save(XKK,"XKK_{}.pth".format(i))print("模型已保存")
writer.close()

CPU训练时间如下
在这里插入图片描述

使用GPU训练时间如下
在这里插入图片描述
对比可知GPU训练速度提升5倍左右

GPU代码区CPU区别如下
在这里插入图片描述
只需要改动网络模型、数据、损失函数,调用它们对应的.cuda()。

方式2
在这里插入图片描述

#定义训练的设备
#device=torch.device("cpu")#使用cpu
device=torch.device("cuda")#使用GPU
....
XKK=xkkk()
XKK=XKK.to(device)
......imgs=imgs.to(device)targets=targets.to(device).....

测试

import torch
import torchvision.transforms
from PIL import Image
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearimage_path="imgs/dog.png"image = (Image.open(image_path).convert("RGB"))
#image=Image.open(image_path)
transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image=transform(image)class xkkk(torch.nn.Module):def __init__(self):super(xkkk,self).__init__()self.model1=torch.nn.Sequential(Conv2d(3, 32, 5,1, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 32, 5, 1,padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 64, 5,stride=1, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(in_features=64*4*4, out_features=64),  # 1024=64*4*4,Linear(64, 10))def forward(self, x):output=self.model1(x)return output
model=torch.load("XKK_9.pth",map_location=torch.device("cpu"))
print(model)
print(image.size())image=torch.reshape(image,(1,3,32,32))
model.eval()
with torch.no_grad():output=model(image)
print(output)
print(output.argmax(1))

输入一张小狗图片

image_path="imgs/dog.png"结果
tensor([5])

在这里插入图片描述

在这里插入图片描述
输入一张飞机图片
在这里插入图片描述

image_path="imgs/airplane.png"
结果
tensor([0])

可以看出预测准确

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1540229.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

AV1 Bitstream Decoding Process Specification--[7]: 语法结构语义-3

原文地址:https://aomediacodec.github.io/av1-spec/av1-spec.pdf 没有梯子的下载地址:AV1 Bitstream & Decoding Process Specification摘要:这份文档定义了开放媒体联盟(Alliance for Open Media)AV1视频编解码…

分发饼干00

题目链接 分发饼干 题目描述 注意点 1 < g[i], s[j] < 2^31 - 1目标是满足尽可能多的孩子&#xff0c;并输出这个最大数值 解答思路 可以先将饼干和孩子的胃口都按升序进行排序&#xff0c;随后根据双指针 贪心&#xff0c;将当前满足孩子胃口的最小饼干分配给该孩…

再次理解UDP协议

一、再谈端口号 在 TCP / IP 协议中&#xff0c;用 "源 IP", "源端口号", "目的 IP", "目的端口号", "协议号" 这样一个五元组来标识一个通信(可以通过 netstat -n 查看) 我们需要端口号到进程的唯一性&#xff0c;所以一个…

Obsidian如何粘贴的图片类似于Typora,图片相对当前路径

添加插件 下载插件&#xff1a; Custom Attachment Location 基础设置 同时需要在下面进行设置 示意效果

大数据多集群数据作业和集群状态监控

目前手里面有四套大数据集群的作业需要维护&#xff0c;分别属于不同的客户&#xff0c;我同岗位的兄弟离职后&#xff0c;所有的数据作业都落到我头上了&#xff0c;公司也不招人了。开发新的数据作业倒没有什么问题&#xff0c;就是客户叫我补数的时候&#xff0c;头比较大&a…

Linux基础权限

Linux基础权限 shell的概念Linux基础权限Linux的两种用户Linux的权限管理权限认知权限设置权限掩码粘滞位 shell的概念 &#xff08;shell&#xff09;命令行解释器 的存在意义&#xff1a; 将用户的命令翻译给操作系统&#xff0c;然后返回OS的结果给用户&#xff1b;保护OS…

YOLOv5图像识别教程包成功-以识别桥墩缺陷详细步骤分享

前置环境资源下载 提示&#xff1a;要开外网才能下载的环境我都放在了网盘里&#xff0c;教程中用到的环境可从这里一并下载&#xff1a; https://pan.quark.cn/s/f0c36aa1ef60 1. 下载YOLOv5源码 官方地址&#xff1a;GitHub - ultralytics/yolov5: YOLOv5 &#x1f680; …

9.4 溪降技术:带包下降

目录 9.4 携包下降概述观看视频课程电子书&#xff1a;携包下降在瀑布中管理背包扔背包滑索传送背包固定到安全带 V7 提示&#xff1a;将背包固定到安全带总结 9.4 携包下降 概述 在水流和悬崖边缘携包下降是最危险的情况&#xff01; 正如我们之前所学&#xff0c;在峡谷探险中…

流程型制造业MES系统的特点及主要功能介绍

流程型MES系统的应用程度较高。特别是石油石化行业原有自动化和信息化的程度较高&#xff0c;一般应用在于生产管控&#xff0c;mes系统的应用主要目的是使得最容易出现产品质量的配料、投料以及乳化加工过程得到管控和追溯。 随着生产工艺发展&#xff0c;石化行业MES系统应用…

Java基础(中)

面向对象基础 面向对象和面向过程的区别 面向过程编程&#xff08;Procedural-Oriented Programming&#xff0c;POP&#xff09;和面向对象编程&#xff08;Object-Oriented Programming&#xff0c;OOP&#xff09;是两种常见的编程范式&#xff0c;两者的主要区别在于解决…

Java设计模式——工厂方法模式(完整详解,附有代码+案例)

文章目录 5.3 工厂方法模式5.3.1概述5.3.2 结构5.3.3 实现 5.3 工厂方法模式 针对5.2.3中的缺点&#xff0c;使用工厂方法模式就可以完美的解决&#xff0c;完全遵循开闭原则。 5.3.1概述 工厂方法模式&#xff1a;定义一个创建对象的接口&#xff08;这里的接口指的是工厂&…

逆向中巧遇MISC图片隐藏

这道题比较有意思&#xff0c;而且因为我对misc并不是很熟悉&#xff0c;发现该题目将flag隐藏在图片的颜色属性&#xff0c;巧妙的跟踪到这些密文位置&#xff0c;拿下题目一血&#xff0c;还是很有参考学习意义的。&#xff08;题目附件&#xff0c;私信发。&#xff09; 1、…

openstack 2023.2 Bobcat 本地安装部署

一、系统环境 rootodoo16e-server:~# cat /etc/lsb-release DISTRIB_IDUbuntu DISTRIB_RELEASE22.04 DISTRIB_CODENAMEjammy DISTRIB_DESCRIPTION"Ubuntu 22.04.5 LTS"rootodoo16e-server:~# python3 --version Python 3.10.12rootodoo16e-server:~# pip --version …

基于YOLOv5s的无人机航拍输电线瓷瓶检测(附数据集与操作步骤)

本文主要内容:详细介绍了无人机航拍输电线瓷瓶检测的整个过程&#xff0c;从创建数据集到训练模型再到预测结果全部可视化操作与分析。 文末有数据集获取方式&#xff0c;请先看检测效果 现状 输电线路绝缘瓷瓶的检测主要依赖人工巡检。巡检人员需携带专业设备&#xff0c;攀…

亿级数据表多线程update锁表问题

目录 1、问题描述 2、原因分析 3、问题解决 1、问题描述 在pg数据库&#xff0c;某个业务&#xff0c;有一张数据表test&#xff0c;数据表结果如下&#xff1a; test(sjjbh,wlbid,gzmb,sfzg,zgsj,cjsj,xx...)&#xff0c;这个表没有主键&#xff0c;会有很多重复数据。 tes…

Vue报错 ‘vite‘ 不是内部或外部命令,也不是可运行的程序或批处理文件

报错 vue-project0.0.0 dev vite‘vite’ 不是内部或外部命令&#xff0c;也不是可运行的程序 或批处理文件。解决 第1步. 控制台输入 npm install -g create-vite第2步. 控制台输入 npm install -g vite第3步. 运行就ok啦

【HTTP】方法(method)以及 GET 和 POST 的区别

文章目录 方法&#xff08;method&#xff09;登录上传GET 和 POST 有什么区别&#xff08;面试&#xff09;区别不准确的说法 方法&#xff08;method&#xff09; 首行中的第一部分。首行是由方法、URL 和版本号组成 方法描述了这次请求想干什么&#xff0c;最主要的是&…

13 vue3之内置组件keep-alive

内置组件keep-alive 有时候我们不希望组件被重新渲染影响使用体验&#xff1b;或者处于性能考虑&#xff0c;避免多次重复渲染降低性能。而是希望组件可以缓存下来,维持当前的状态。这时候就需要用到keep-alive组件。 开启keep-alive 生命周期的变化 初次进入时&#xff1a;…

基于SpringBoot+Vue的私人牙科诊所管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

MySQL函数介绍--日期与时间函数(一)

我相信大家在学习各种语言的时候或多或少听过我们函数或者方法这一类的名词&#xff0c;函数在计算机语言的使用中可以说是贯穿始终&#xff0c;那么大家有没有思考过到底函数是什么&#xff1f;函数的作用又是什么呢&#xff1f;我们为什么要使用函数&#xff1f;其实&#xf…