pytorch学习(十一)checkpoint

当训练一个大模型数据的时候,中途断电就可以造成已经训练几天或者几个小时的工作白做了,再此训练的时候需要从epoch=0开始训练,因此中间要不断保存(epoch,net,optimizer,scheduler)等几个内容,这样才能在发生意外之后快速恢复工作。

通过本博客的学习,你将学会最优模型保存和模型自动加载的方法

2.首先看代码

    1.2保存最优模型

保存最优模型的代码如下:

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)

 min_loss_val 定义成全局的变量之后,应该在用到的函数中,使用global min_loss_val再次定义,否则会报错误

 local variable 'min_loss_val' referenced before assignment 

1.2自动加载模型

#找文件夹中数字最大的文件
def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_bestResume = Falseif Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])

2.代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import recur_pwd_path = os.getcwd()def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_besttensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():tensor = tensor.to(device)print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':tensor = tensor.to(torch.device("cpu"))print("tensor is on", tensor.device)
if tensor.device == "cpu":tensor = tensor.to(torch.device("cuda:0"))print("tensor is on", tensor.device)trainning_data =datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))num_epoches =10
min_loss_val = 100000
Resume = Falsedef train():global min_loss_valstart_epoch = -1if Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()tensorboard_ind =0;for epoch in range(num_epoches):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()tensorboard_ind += 1train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试test_loss_value = 0model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)test_loss_value = test_lossprint("test result:",100 * test_correct,"%  avg loss:",test_loss)scheduler_1.step()#设置checkpointif epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)# Press the green button in the gutter to run the script.if __name__ == '__main__':train()

3.运行结果为

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1486496.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第274题H指数

题目&#xff1a; 题解&#xff1a; class Solution {public int hIndex(int[] citations) {int left0,rightcitations.length;int mid0,cnt0;while(left<right){// 1 防止死循环mid(leftright1)>>1;cnt0;for(int i0;i<citations.length;i){if(citations[i]>mi…

Vue 3 实现左侧列表点击跳转滚动到右侧对应区域的功能

使用 Vue 3 实现左侧列表点击跳转到右侧对应区域的功能 1. 引言 在这篇博客中&#xff0c;我们将展示如何使用 Vue 3 实现一个简单的页面布局&#xff0c;其中左侧是一个列表&#xff0c;点击列表项时&#xff0c;右侧会平滑滚动到对应的内容区域。这种布局在很多应用场景中都…

金字塔思维:打造清晰有力的分析报告与沟通技巧

金字塔思维&#xff1a;打造清晰有力的分析报告与沟通技巧 在职场中&#xff0c;撰写一份条理清晰、逻辑严谨、说服力强的分析报告是每位职场人士必备的技能。然而&#xff0c;许多人在完成报告后常常感到思路混乱&#xff0c;表达不清。为了帮助大家解决这一问题&#xff0c;本…

VSCode部署Pytorch机器学习框架使用Anaconda(Window版)

目录 1. 配置Anaconda1.1下载安装包1. Anaconda官网下载2, 安装Anaconda 1.2 创建虚拟环境1.3 常用命令Conda 命令调试和日常维护 1.4 可能遇到的问题执行上述步骤后虚拟环境仍在C盘 2. 配置cuda2.1 查看显卡支持的cuda版本2.2 下载对应cuda版本2.3 下载对应的pytorch可能出现的…

学习React(状态管理)

随着你的应用不断变大&#xff0c;更有意识的去关注应用状态如何组织&#xff0c;以及数据如何在组件之间流动会对你很有帮助。冗余或重复的状态往往是缺陷的根源。在本节中&#xff0c;你将学习如何组织好状态&#xff0c;如何保持状态更新逻辑的可维护性&#xff0c;以及如何…

SQL 简单查询

目录 一、投影查询 1、指定特定列查询 2、修改返回列名查询 3、计算值查询 二、选择查询 1、使用关系表达式 2、使用逻辑表达式 3、使用 BETWEEN关键字 4、使用 IN关键字 5、使用 LIKE关键字 6、使用 IS NULL/ NOT NULL关键字 7、符合条件查询 三、聚合函数查询 一…

vuepress搭建个人文档

vuepress搭建个人文档 文章目录 vuepress搭建个人文档前言一、VuePress了解二、vuepress-reco主题个人博客搭建三、vuepress博客部署四、vuepress后续补充 总结 vuepress搭建个人文档 所属目录&#xff1a;项目研究创建时间&#xff1a;2024/7/23作者&#xff1a;星云<Xing…

Nuxt 使用指南:掌握 useNuxtApp 和运行时上下文

title: Nuxt 使用指南&#xff1a;掌握 useNuxtApp 和运行时上下文 date: 2024/7/21 updated: 2024/7/21 author: cmdragon excerpt: 摘要&#xff1a;“Nuxt 使用指南&#xff1a;掌握 useNuxtApp 和运行时上下文”介绍了Nuxt 3中useNuxtApp的使用&#xff0c;包括访问Vue实…

[Spring] Spring日志

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…

python实现责任链模式

把多个处理方法串成一个list。下一个list的节点是上一个list的属性。 每个节点都有判断是否能处理当前数据的方法。能处理&#xff0c;则直接处理&#xff0c;不能处理则调用下一个节点&#xff08;也就是当前节点的属性&#xff09;来进行处理。 Python 实现责任链模式&#…

在浏览器中测试JavaScript代码方法简要介绍

在浏览器中测试JavaScript代码方法简要介绍 在浏览器中测试JavaScript代码是前端开发中的一个重要技能。方法如下&#xff1a; 1. 浏览器控制台 最简单和直接的方法是使用浏览器的开发者工具中的控制台&#xff08;Console&#xff09;。 步骤&#xff1a; 在大多数浏览器…

Unity 之 【Android Unity 共享纹理】之 Android 共享图片给 Unity 显示

Unity 之 【Android Unity 共享纹理】之 Android 共享图片给 Unity 显示 目录 Unity 之 【Android Unity 共享纹理】之 Android 共享图片给 Unity 显示 一、简单介绍 二、共享纹理 1、共享纹理的原理 2、共享纹理涉及到的关键知识点 3、什么可以实现共享 不能实现共享…

有关于链表带环的两道OJ题目

目录 1.判断链表是否带环 1.1快指针的速度为慢指针的2倍 1.2快指针的速度为慢指针的3倍 2.找出带环链表开始入环的第一个节点 2.1将快慢指针相遇的节点与后面分开&#xff0c;构造交叉链表 2.2记录快慢指针相遇节点&#xff0c;与头结点一起向后走&#xff0c;相遇点为入…

远程开启空调,享受即刻凉爽

随着夏季的热浪逐渐侵袭&#xff0c;我们都渴望回到家中那一刻&#xff0c;能感受到一丝丝的凉意。但是&#xff0c;有时候&#xff0c;即使我们提前开窗通风&#xff0c;房间里的温度依然像烤箱一样闷热难耐。那么&#xff0c;有没有一种方法&#xff0c;能让我们在外头酷暑难…

升级Nvidia CUDA 遇到 sub-process /usr/bin/dpkg returned an error code (1)

1.问题描述 在自己Ubuntu22.04的服务器环境上存在cuda版本为11.5&#xff0c;按照官网教程升级为12.1运行安装命令 sudo apt-get -y install cuda 报错&#xff1a;sub-process /usr/bin/dpkg returned an error code (1) 官网教程&#xff1a; https://developer.nvidia…

PCIE软件基础知识

什么是PCIE PCIe&#xff0c;全称 Peripheral Component Interconnect Express&#xff0c;是一种高速串行计算机扩展总线标准&#xff0c;用于连接计算机内部的硬件组件&#xff0c;如显卡、存储设备、网络适配器等。PCIe是一种点对点的双向通信标准&#xff0c;这意味着它在发…

微信小程序canvas 使用案例(一)

一、cavans 对象获取、上线文创建 1.wxml <!-- canvas.wxml --><canvas type"2d" id"myCanvas"></canvas> 2.js /*** 生命周期函数--监听页面加载*/onLoad(options) {const query wx.createSelectorQuery()query.select(#myCanvas).f…

分离式网络变压器的集成化设计替代传统网络变压器(网络隔离滤波器)尝试

Hqst盈盛&#xff08;华强盛&#xff09;电子导读&#xff1a;今天分享的是应用了分离式网络变压器设计的的新型网络变压器&#xff08;网络隔离变压器&#xff09; 今天我们一起来看这款新型网络变压器&#xff0c;它就是应用分离式网络变压器集成到电路板上&#xff0c;加上外…

CAS乐观锁原理

1、什么是CAS&#xff1f; compare and swap也就是比较和交换&#xff0c;他是一条CPU的并发原语。 他在替换内存的某个位置的值时&#xff0c;首先查看内存中的值与预期值是否一致&#xff0c;如果一致&#xff0c;执行替换操作。 这个操作是一个原子性操作。 Java中基于Un…

昇思学习打卡-21-生成式/Diffusion扩散模型

文章目录 Diffusion扩散模型介绍模型推理结果 Diffusion扩散模型介绍 关于扩散模型&#xff08;Diffusion Models&#xff09;有很多种理解&#xff0c;除了本文介绍的离散时间视角外&#xff0c;还有连续时间视角、概率分布转换视角、马尔可夫链视角、能量函数视角、数据增强…