【深度学习】(7)--保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset) # 总数据大小num_batches = len(dataloader) # 划分的小批次数量model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)print(model.state_dict().keys()) # 输出模型参数名称cnntorch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 2. 保存完整模型(w,b,模型cnn)torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/150736.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

同等学力英语考试词汇是多少

很多考生想知道同等学力英语考试词汇是多少,词汇量要求大约在4000-6000个单词之间。 这一要求介于大学英语四级单词和六级之间,具体词汇量可能会根据不同的资料有所差异。 例如有的资料指出,考生应掌握约6000个英语词汇和约700个常用词组&am…

QT版数据采集系统研发过程记录

研发目的:通过智能监测设备将各个变电站运行的电压、电流、温湿度等数据采集汇总到计算机中心服务器,通过系统软件展示各个站点对应的运行工况。 软件架构:使用QT开发跨平台(Windows系统、Ubuntu20.04)客户端软件、连…

探秘淘宝商品评论电商API接口:提升用户购买决策的终极武器!

随着互联网技术的不断发展和电商市场的迅速扩张,消费者在网购时对商品信息的获取和筛选变得尤为重要。其中,淘宝商品评论作为消费者了解商品质量、性能以及使用体验的重要途径,其价值和影响力日益凸显。本文将深入探讨淘宝商品评论电商API接口…

恢复 iPhone 16 上不见照片的 4 种简便方法

几乎每部 iPhone 都有一个共同点,那就是照片应用程序,里面装满了大量的照片和视频。人们购买 iPhone 的原因之一是其出色的相机质量。那么,如果你突然丢失了照片,你会有什么感觉呢 如果您不小心删除了照片或手机故障导致所有照片…

基于SpringBoot - Netty框架的云快充协议(充电桩协议)

云快充协议是一种标准通信协议,主要用于电动车与充电桩之间的数据交换。该协议包含了充电请求、状态查询、支付等多个功能模块。这些功能的实现不仅需要对协议进行深入理解,还需要编写相应的代码进行封装。 软件架构 1、提供云快充底层桩直连协议&#…

Nexus学习

系列文章目录 第一章 基础知识、数据类型学习 第二章 万年历项目 第三章 代码逻辑训练习题 第四章 方法、数组学习 第五章 图书管理系统项目 第六章 面向对象编程:封装、继承、多态学习 第七章 封装继承多态习题 第八章 常用类、包装类、异常处理机制学习 第九章 集…

Swing模拟银行柜台系统

> 这是一个基于JavaSwing实现的模拟银行柜台系统。 > 具有管理员、柜员、客户三种登录角色。 > 支持开户、注册、存取款、转账、汇款、账单查询等功能。 > 本项目适合JAVA初学者作为入门学习项目。 一、部分界面演示 二、基础依赖 技术/框架版本描述Java11编…

sql中的having与where对比

sql中的having与where对比 1、语法差异2、影响结果范围3、索引使用4、聚合函数5、总结 💖The Begin💖点点关注,收藏不迷路💖 在SQL中,having和where都是用来过滤数据的,但它们之间存在一些关键的不同点。 …

mac输入法 cpu占用,解决mac使用输入法出现卡顿延迟

1、介绍 网上有各种方法,例如有touchbar的macbook关闭输入建议;定时重启“简体中文输入法”进程;关闭“显示器具有单独的空间” 这些方法网上都能看到,有些人说能解决,有些人说还是卡,我试过了问题依然存在…

书生大模型实战(从入门到进阶)L1-InternLM + LlamaIndex RAG 实践

目录 配置基础环境 安装 Llamaindex 下载 Sentence Transformer 模型 下载 NLTK 相关资源 LlamaIndex HuggingFaceLLM LlamaIndex RAG LlamaIndex web 本文是对书生大模型L1-InternLM LlamaIndex RAG 实践部分的学习和实现,学习地址如下: 学习地…

JVM基本了解

一、JVM 基本组成 1、JDK\JRE\JVM JDK:全称“Java Development Kit”Java 开发工具包,提供 javac 编译器、jheap、jconso1e 等监控工具;JRE:全称“Java Runtime Environment”Java 运行环境,提供Class Library 核心类库 JVM;JVM:全称“Java Virtual Ma…

XILINX ZYNQ 7000 UART EMIO 串口IO扩展

当需要使用到PL端的IO口用作串口的时候可以使用EMIO对UART的引脚进行扩展 这里使用UART1 进行EMIO扩展 EMIO本质上是属于PL FPGA的资源所以需要进行综合然后再指定管脚 然后把UART1,TX RX做外部引脚 生成bit流文件,然后导入到SDK 创建一个API&#x…

如何解决跨域请求中的 CORS 错误

聚沙成塔每天进步一点点 本文回顾 ⭐ 专栏简介如何解决跨域请求中的 CORS 错误1. 引言2. 什么是 CORS?2.1 同源策略示例: 2.2 CORS 请求的类型 3. CORS 错误的原因3.1 常见 CORS 错误示例 4. 解决 CORS 错误的常见方法4.1 在服务器端启用 CORS4.1.1 Node…

使用Jlink打印单片机的调试信息

1.在工程中添加6个文件 除去RTT_Debug.h外的其他几个文件在jlink安装目录 RTT_Debug.h的内容如下 #ifndef _RTT_H_ #define _RTT_H_#include "SEGGER_RTT.h"#define STR_DEBUG //#define USART_DEBUG#define DBGLOG #define DBGWARNING #define DBGERROR#if def…

【自动驾驶】基于车辆几何模型的横向控制算法 | Stanley 算法详解与编程实现

写在前面: 🌟 欢迎光临 清流君 的博客小天地,这里是我分享技术与心得的温馨角落。📝 个人主页:清流君_CSDN博客,期待与您一同探索 移动机器人 领域的无限可能。 🔍 本文系 清流君 原创之作&…

RAG(Retrieval-Augmented Generation)检索增强生成技术基础了解学习与实践

RAG(Retrieval-Augmented Generation)是一种结合了信息检索(Retrieval)和生成模型(Generation)的技术,旨在提高生成模型的性能和准确性。RAG 技术通过在生成过程中引入外部知识库,使…

设计模式之装饰模式(Decorator)

前言 这个模式带给我们有关组合跟继承非常多的思考 定义 “单一职责” 模式。动态(组合)的给一个对象增加一些额外的职责。就增加功能而言,Decorator模式比生成子类(继承)更为灵活(消除重复代码 & 减少…

深入探索卷积神经网络(CNN)

深入探索卷积神经网络(CNN) 前言图像的数字表示灰度图像RGB图像 卷积神经网络(CNN)的架构基本组件卷积操作填充(Padding)步幅(Strides) 多通道图像的卷积池化层全连接层 CNN与全连接…

c++难点核心笔记(二)

系列文章目录 c难点&核心笔记(一) 继续接着上一章记录的重点内容包括函数,类和对象,指针和引用,C对象模型和this指针等内容,继续给大家分享!! 文章目录 系列文章目录友元全局函数做友元类做友元成员函…

傅里叶变换及其应用笔记

傅里叶变换 预备知识学习路线扼要描述两者之间的共同点:线性运算周期性现象对称性与周期性的关系周期性 预备知识 学习路线 从傅里叶级数,过度到傅里叶变换 扼要描述 傅里叶级数(Fourier series),几乎等同于周期性…