模型的深度优化

文章目录

    • 一、测试模型是否正确
    • 二、图形打印直观观察
    • 三、保存训练模型
    • 四、正确率(仅使用于分类问题)

一、测试模型是否正确

本文承接我的上一篇文章完整网络模型训练(一)
运用测试数据集(test_dataloader)进行检验

total_test_loss = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = sen(imgs)#这里的loss是一个tensor数据类型loss = loss_fn(outputs, targets)#加上item进行转换成整形total_test_loss = total_test_loss + loss.item()print(f"整体测试集上的Loss{total_test_loss}")

注释:
with torch.no_grad()使用 no_grad() 上下文,禁用梯度计算,以节省内存和加快计算速度(因为在测试阶段不需要反向传播)

在pycharm的运行框中按下crtl+f可以弹出搜索框
运行结果:
在这里插入图片描述
可以看到确实有整体数据集的信息,但是被很多乱七八遭的信息给掩盖了,所以可以改善一下代码

 if total_train_step % 100 == 0:print(f"训练次数:{total_train_step},Loss:{loss.item()}")

每当步骤为整百的时候才进行打印,这样子输出结果能够整洁一点:
在这里插入图片描述

二、图形打印直观观察

也可以用tensorboard进行画图观察,补上一下代码:

writer = SummaryWriter("./logs_train")
 if total_train_step % 100 == 0:print(f"训练次数:{total_train_step},Loss:{loss.item()}")writer.add_scalar("train_loss", loss.item(), total_train_step)
print(f"整体测试集上的Loss{total_test_loss}")writer.add_scalar("test_loss", total_test_loss,total_test_step)total_test_step = total_test_step + 1

运行结果:
在这里插入图片描述
在这里插入图片描述
可以直观的看到训练的模型的loss损失函数在不断的下降。

三、保存训练模型

还可以保存每一轮训练的模型:

torch.save(sen, f"sen_{i}.pth")print("模型已保存")

使用 torch.save() 将模型 ‘sen’ 保存到文件中,文件名为 “sen_{i}.pth”

四、正确率(仅使用于分类问题)

哪怕我们已经得到整体测试数据集上的Loss,也不能很好的说明数据集实际上的表现效果
在分类问题中我们可以用正确率进行表示

    total_accuracy = 0
accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracy

使用 writer.add_scalar() 方法,将标量数据记录到日志文件中,用于在 TensorBoard 中可视化
“test_accuracy” 是图表的名称,用于表示测试集的准确率
total_accuracy/test_data_size 表示计算出的测试集上的准确率(总正确预测数 / 测试集数据总量)
total_test_step 表示当前测试步骤,用作 X 轴上的标记,表明该数据是在第几次测试时记录的

    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)

运行结果:
在这里插入图片描述
可以看到计算处理整体数据集上的正确率,为32%

有时我们能看到有些代码会写上以下两段代码:

在训练模型开始的时候

sen.train()

例如:

	sen.train()for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()

在测试步骤开始的时候

sen.eval()

例如:

 	sen.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = sen(imgs)#这里的loss是一个tensor数据类型loss = loss_fn(outputs, targets)#加上item进行转换成整形total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracy

这些只对特定模型起作用,例如dropout,不过一般都写上,反正无影响。

训练模型总体过程
准备数据->加载数据->准备模型->设置损失函数->设置优化器->开始训练->最后验证->结果聚合展示

整体代码展示:

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter#创建网络模型
class Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1 ,2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")#利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)#创建网络模型
sen = Sen()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10#添加tensorboard
writer = SummaryWriter("./logs_train")for i in range(epoch):print(f"-------第{i+1}轮训练开始-------")#训练步骤开始sen.train()for data in train_dataloader:imgs, targets = dataoutputs = sen(imgs)loss = loss_fn(outputs, targets)#优化器模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print(f"训练次数:{total_train_step},Loss:{loss.item()}")writer.add_scalar("train_loss", loss.item(), total_train_step)#测试步骤开始sen.eval()total_test_loss = 0total_accuracy = 0#不需要梯度进行调整或者优化with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = sen(imgs)#这里的loss是一个tensor数据类型loss = loss_fn(outputs, targets)#加上item进行转换成整形total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint(f"整体测试集上的Loss{total_test_loss}")print(f"整体数据集上的正确率;{total_accuracy/test_data_size}")writer.add_scalar("test_loss", total_test_loss,total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)total_test_step = total_test_step + 1writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1554918.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【宽搜】4. leetcode 103 二叉树的锯齿形层序遍历

1 题目描述 题目链接:二叉树的锯齿形层序遍历 2 题目解析 根据题目描述,第一行是从左往右遍历,第二行是从右往左遍历。和层序遍历的区别就是: 在偶数行需要从右往左遍历。 因此,只需要在层序遍历的基础上增加一个变…

【WebGis开发 - Cesium】三维可视化项目教程---初始化场景

系列文章目录 未完待续~ 目录 系列文章目录引言一、Cesium引入项目1.1 下载资源1.2 项目引入Cesium 二、初始化地球2.1 创建基础文件2.1.1 创建Cesium工具方法文件2.1.2 创建主页面 2.2 看下效果 三、总结 引言 本教程主要是围绕Cesium这一开源三维框架开展的可视化项目教程。…

银河麒麟服务器镜像完整性验证:MD5校验

银河麒麟服务器镜像完整性验证:MD5校验 步骤一:获取标准MD5值步骤二:计算MD5值步骤三:对比MD5值 💖The Begin💖点点关注,收藏不迷路💖 在下载或传输银河麒麟服务器镜像时&#xff0c…

Oracle架构之表空间详解

文章目录 1 表空间介绍1.1 简介1.2 表空间分类1.2.1 SYSTEM 表空间1.2.2 SYSAUX 表空间1.2.3 UNDO 表空间1.2.4 USERS 表空间 1.3 表空间字典与本地管理1.3.1 字典管理表空间(Dictionary Management Tablespace,DMT)1.3.2 本地管理方式的表空…

Ubuntu 中 Redis ,MySQL 基本使用

1、Redis (1)启动Redis 服务端客户端命令 服务端 ps aux | grep redis 查看redis服务器进程 sudo kill -9 pid 杀死redis服务器 sudo redis-server /etc/redis/redis.conf 指定加载的配置文件客户端 连接redis: redis-cli运⾏测试命令&am…

《python语言程序设计》2018版第8章19题几何Rectangle2D类(上)--原来我可以直接调用

2024.9.29 玩了好几天游戏。 感觉有点灵感了。还想继续玩游戏。 2024.10.4 今天练习阿斯汤加练完从早上10点睡到下午2点.跑到单位玩游戏玩到晚上10点多. 现在回家突然有了灵感 顺便说一句,因为后弯不好,明天加练一次. 然后去丈母娘家. 加油吧 第一章、追求可以外调的函数draw_r…

【Python】pyenv:管理多版本 Python 环境的利器

pyenv 是一个强大的 Python 版本管理工具,它允许开发者在同一台计算机上轻松安装和管理多个 Python 版本。对于需要在不同项目中使用不同 Python 版本的开发者来说,pyenv 是一个非常有用的工具,因为它可以帮助用户在全局和项目级别控制 Pytho…

C/C++/EasyX——入门图形编程(4)

【说明】紧接上文(。・ω・。),好了,接下来,就让我们开始学习图像处理和获取鼠标消息的函数吧。(各位友友们不要着急,想在短时间内就想做小游戏或者写出各种好看的画面是不简…

小白快速上手 Docker 03 | Docker数据卷

数据卷 在前面使用Docker时,可能会遇到以下几个问题: 当Docker 里的容器挂了以后打不开,这时候只有删除该容器了,但删除容器会连容器中的产生的数据也一起删除了,大部分场景下这是不能接受的。Docker容器与容器之间不…

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。 【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。 文章目录 【深度学习基础模型】深度残差网络&a…

使用前端三剑客实现一个备忘录

一,界面介绍 这个备忘录的界面效果如下: 可以实现任务的增删,并且在任务被勾选后会被放到已完成的下面。 示例: (1),增加一个任务 (2),勾选任务 &#xff…

Chat登录时出现SSO信息出错的解决方法

目录 1. 问题所示2. 问题所示3. 解决方法 1. 问题所示 此贴主要是总结回顾,对此放置在运维专栏 出现如下问题,很懵,以为是节点挂了还是网址蹦了 一直刷新,登录之后就出现这个问题 2. 问题所示 对于SSO,也就是单点登…

ExcelToWord-Excel套打Word-Word邮件合并工具分享

Excel to Word转换工具分享 在日常工作或学习中,我们常常需要将Excel中的数据导出到Word文档中,以便更好地展示信息。市场上有许多Excel to Word的转换工具,它们各有特色。今天,我们就来推荐几款这样的工具,并探讨一下…

基于Springboot+Vue的教师科研管理系统 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统中…

用Python实现运筹学——Day 12: 线性规划在物流优化中的应用

一、学习内容 线性规划在物流优化中可以用于解决诸如配送路径优化、货物运输调度等问题。配送中心的路径优化问题本质上是寻找一条最优路径,在满足需求点的需求条件下,最小化配送的总运输成本或时间。常见的物流优化问题包括: 配送中心的货…

Python小示例——质地不均匀的硬币概率统计

在概率论和统计学中,随机事件的行为可以通过大量实验来研究。在日常生活中,我们经常用硬币进行抽样,比如抛硬币来决定某个结果。然而,当我们处理的是“质地不均匀”的硬币时,事情就变得复杂了。质地不均匀的硬币意味着…

【C++】—— 类和对象(中)

【C】—— 类和对象(中) 文章目录 【C】—— 类和对象(中)前言1. 类的默认成员函数2. 构造函数3. 析构函数4. 拷贝构造函数5. 赋值运算符重载5.1 运算符重载5.2 赋值运算符重载 结语 前言 小伙伴们大家好呀,昨天的 【C】——类和对象(上) 大家理解的怎么样了 今天…

网约班车升级手机端退票

背景 作为老古董程序员,不,应该叫互联网人员,因为我现在做的所有的事情,都是处于爱好,更多的时间是在和各行各业的朋友聊市场,聊需求,聊怎么通过IT互联网 改变实体行业的现状,准确的…

卡码网KamaCoder 53. 寻宝

题目来源:53. 寻宝(第七期模拟笔试) C题解(来源代码随想录):最小生成树 prim prim三部曲 第一步,选距离生成树最近节点第二步,最近节点加入生成树第三步,更新非生成树节…

随时随地,轻松翻译:英汉互译软件的便捷之旅

翻译英汉互译工具,就如同一位随时待命的语言助手,在这纷繁复杂的语言世界中为我们搭建起理解与沟通的桥梁。接下来,让我们一同深入了解这些神奇的英汉互译工具,探索它的诸多功能和独特魅力。 1.福晰在线翻译 链接直达>>h…