PyTorch(六)优化模型参数

#c 目的 优化的目的

已经拥有了一个「模型」和「数据」,是时候通过「优化模型参数」来训练、验证和测试模型。

#d 迭代训练

训练模型是一个迭代过程;在每次迭代中,模型对输出做出猜测,计算其猜测的误差(损失),收集误差相对于其参数的导数,并使用「梯度下降」来优化这些参数。

1 超参数(Hyperparameters)

#d 超参数

超参数(Hyperparameters)是可调节的参数,它们允许你「控制模型优化过程」。不同的超参数值可能会影响模型的训练和收敛速率。

#e 超参数定义例子 超参数

周期数(Epochs):迭代数据集的次数。
批量大小(Batch Size):在更新参数之前通过网络传播的数据样本数量。
学习率(Learning Rate):在每个批量/周期中更新模型参数的程度。较小的值会导致学习速度慢,而较大的值可能会导致训练过程中出现不可预测的行为。

learning_rate = 1e-3
batch_size = 64
epochs = 5

2 优化循环,损失函数,优化器

#d 优化循环

一旦设置了超参数,就可以通过优化循环来训练和优化模型。优化循环的每一次迭代称为一个「周期(Epoch)」。

每个「周期」由两个主要部分组成:

  1. 训练循环(Train Loop):遍历训练数据集,尝试收敛到最优参数。
  2. 验证/测试循环(Validation/Test Loop):遍历测试数据集,以检查模型性能是否在提高。

#d 损失函数作用

当提供一些训练数据时,未经训练的网络很可能无法给出正确的答案。
「损失函数」衡量了所得「结果」与「目标值」之间的不相似程度,而在训练过程中,希望最小化的就是这个损失函数。为了计算损失,使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。

#e 常见损失函数 损失函数的作用

nn.MSELoss(均方误差):用于回归任务。
nn.NLLLoss(负对数似然损失):用于分类任务。
nn.CrossEntropyLoss:结合了nn.LogSoftmax和nn.NLLLoss的功能。
将模型的输出(logits)传递给nn.CrossEntropyLoss,它将对logits进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()#初始化损失函数

#d 优化器作用

优化器(Optimizer)是「调整模型参数」以减少每个训练步骤中的「模型误差」的过程。优化算法定义了这个过程是如何执行的。所有的优化逻辑都被封装在优化器optimizer对象中。在PyTorch中还有许多不同的优化器可用,例如ADAM和RMSProp,它们对不同类型的模型和数据有更好的效果。

#e SGD(随机梯度下降) 优化器作用

通过注册需要训练的模型参数,并传入学习率超参数来初始化优化器。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#初始化优化器

#d 优化器作用位置

在训练循环中,优化发生在三个步骤中:

  1. 梯度清零:调用optimizer.zero_grad()来重置模型参数的梯度。梯度默认情况下是累加的;为了防止重复计算,在每次迭代中明确地将它们清零。

  2. 反向传播:通过调用loss.backward()对预测损失进行反向传播。PyTorch会计算损失相对于每个参数的梯度。

  3. 参数更新:一旦有了梯度,就调用optimizer.step()根据反向传播过程中收集的梯度来调整参数。

这个过程确保了模型在每次迭代中都能朝着减少损失的方向更新参数。

3 完整过程

定义循环优化代码的train_loop,以及根据测试数据评估模型性能的test_loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda# 加载数据集
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# 创建模型
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)#数据集的大小model.train()#设置模型为训练模式for batch, (X, y) in enumerate(dataloader):pred = model(X)#前向传播loss = loss_fn(pred, y)# 反向传播loss.backward()optimizer.step()#参数更新optimizer.zero_grad()#梯度清零if batch % 100 == 0:#每100个批次打印一次loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test_loop(dataloader, model, loss_fn):model.eval()#设置模型为评估模式size = len(dataloader.dataset)#数据集的大小test_loss, correct = 0, 0#测试损失和正确数num_batches = len(dataloader)#批次数with torch.no_grad():#关闭梯度跟踪for X, y in dataloader:pred = model(X)#前向传播test_loss += loss_fn(pred, y).item()#计算损失correct += (pred.argmax(1) == y).type(torch.float).sum().item()#计算正确数test_loss /= num_batches#计算平均损失correct /= size#计算正确率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")'''
初始化损失函数和优化器,并将其传递给train_loop和test_loop。随意增加轮数以跟踪模型的改进性能。
'''
loss_fn = nn.CrossEntropyLoss()#初始化损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#初始化优化器
epochs = 10#周期数for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473452.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

如何在Windows 11上复制文件和文件夹路径?这里提供几种方法

在Windows 11上复制文件或文件夹的路径就像在右键单击菜单中选择一个选项或按键盘快捷键一样简单。我们将向你展示如何在电脑上以各种方式进行操作。 从右键单击菜单 复制文件或文件夹路径的最简单方法是在该项目的右键单击菜单中选择一个选项。你也可以使用此方法复制多个项…

电表读数检测数据集VOC+YOLO格式18156张12类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):18156 标注数量(xml文件个数):18156 标注数量(txt文件个数):18156 标…

使用JAR命令打包JAR文件使用Maven打包使用Gradle打包打包Spring Boot应用

本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(注明:作者:王文峰…

vue 模糊查询加个禁止属性

vue 模糊查询加个禁止属性 父组件通过属性传,是否禁止输入-------默认可以输入

VirtualBox 安装 Ubuntu Server24.04

环境: ubuntu-2404-server、virtualbox 7.0.18 新建虚拟机 分配 CPU 核心和内存(根据自己电脑实际硬件配置选择) 分配磁盘空间(根据自己硬盘实际情况和需求分配即可) 设置网卡,网卡1 负责上网&#xff0c…

字符串相似度算法完全指南:编辑、令牌与序列三类算法的全面解析与深入分析

在自然语言处理领域,人们经常需要比较字符串,这些字符串可能是单词、句子、段落甚至是整个文档。如何快速判断两个单词或句子是否相似,或者相似度是好还是差。这类似于我们使用手机打错一个词,但手机会建议正确的词来修正它&#…

【VUE基础】VUE3第三节—核心语法之ref标签、props

ref标签 作用&#xff1a;用于注册模板引用。 用在普通DOM标签上&#xff0c;获取的是DOM节点。 用在组件标签上&#xff0c;获取的是组件实例对象。 用在普通DOM标签上&#xff1a; <template><div class"person"><h1 ref"title1">…

使用 PyTorch 创建的多步时间序列预测的 Encoder-Decoder 模型

Encoder-decoder 模型在序列到序列的自然语言处理任务&#xff08;如语言翻译等&#xff09;中提供了最先进的结果。多步时间序列预测也可以被视为一个 seq2seq 任务&#xff0c;可以使用 encoder-decoder 模型来处理。本文提供了一个用于解决 Kaggle 时间序列预测任务的 encod…

笔记13:switch多分支选择语句

引例&#xff1a; 输入1-5中的任意一共数字&#xff0c;对应的打印字符A,B,C,D,E int num 0; printf("Input a number[1,5]:"); scanf("%d"&#xff0c;&num); if( num 1)printf("A\n"); else if(num2)printf("B\n"); else i…

ZYNQ7020的bank引脚分区

一张图看ZYNQ7000的资源分布 从图中看出BANK33 34 35是ZYNQ的PL部分 也就是FPGA部分PS部分在BANK0 500 501&#xff0c;DDR控制器连接在PS部分BANK33的电压可调

ePTFE膜(膨体聚四氟乙烯膜)应用前景广阔 本土企业技术水平不断提升

ePTFE膜&#xff08;膨体聚四氟乙烯膜&#xff09;应用前景广阔 本土企业技术水平不断提升 ePTFE膜全称为膨体聚四氟乙烯膜&#xff0c;指以膨体聚四氟乙烯&#xff08;ePTFE&#xff09;为原材料制成的薄膜。ePTFE膜具有耐化学腐蚀、防水透气性好、耐候性佳、耐磨、抗撕裂等优…

CTF常用sql注入(三)无列名注入

0x06 无列名 适用于无法正确的查出结果&#xff0c;比如把information_schema给过滤了 join 联合 select * from users;select 1,2,3 union select * from users;列名被替换成了1,2,3&#xff0c; 我们再利用子查询和别名查 select 2 from (select 1,2,3 union select * f…

中英双语介绍伦敦金融城(City of London)

中文版 伦敦金融城&#xff0c;通常称为“金融城”或“城”&#xff08;The City&#xff09;&#xff0c;是英国伦敦市中心的一个著名金融区&#xff0c;具有悠久的历史和全球性的影响力。以下是关于伦敦金融城的详细介绍&#xff0c;包括其地理位置、人口、主要公司、历史背…

关于在自行封装的组件库中(使用vue-class-component)使用Vue-i18n无法正常翻译的解决办法

文章目录 介绍背景现象1解决办法 现象2原因分析解决办法 最终方案 介绍 大家或多或少都用过别人封装的组件库&#xff0c;甚至有人或者公司内有自行封装的一些公用组件库&#xff0c;而国际化翻译现在已经是各大项目中必不可少的一个插件了&#xff0c;但组件库中使用 i18n 进…

计算机网络 0319

OSPF协议&#xff1a;开放式最短路径优先 协议 基于代价的路由协议 适合与大型的网络 DR 指定路由器 BDR 备用指定路由器 OSPF的组播地址 224.0.0.5 224.0.0.6 RIP组播地址&#xff1a;224.0.0.9 OSPF数据包 过程&#xff1a;先各个发送hello包认识&#xff0c;成为邻居…

深圳航空顶象验证码逆向,和百度验证码训练思路

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 前言(lianxi a…

CC2530寄存器编程学习笔记_点灯

下面是我的CC2530的学习笔记之点灯部分。 第一步&#xff1a;分析原理图 找到需要对应操作的硬件 图 1 通过这个图1我们可以找到LED1和LED2连接的引脚&#xff0c;分别是P1_0和P1_1。 第二步 分析原理图 图 2 通过图2 确认P1_0和P1_1引脚连接到LED&#xff0c;并且这些引…

51单片机———LED点阵屏显示图形动画

单片机上的一小块屏幕就是LED点阵屏&#xff0c;与数码管一样&#xff0c;内部由LED灯组成&#xff0c;只是点阵屏使用的LED灯更多&#xff0c;LED灯呈矩形分布而非“8”字形&#xff1b;并且点阵屏和数码管一样&#xff0c;有两种接法共阳极和共阳极&#xff1b; 16*16LED点阵…

springboot集成tika解析word,pdf,xls文件文本内容

介绍 Apache Tika 是一个开源的内容分析工具包&#xff0c;用于从各种文档格式中提取文本和元数据。它支持多种文档类型&#xff0c;包括但不限于文本文件、HTML、PDF、Microsoft Office 文档、图像文件等。Tika 的主要功能包括内容检测、文本提取和元数据提取。 官网 https…

vite+vue3整合less教程

1、安装依赖 pnpm install -D less less-loader2、定义全局css变量文件 src/assets/css/global.less :root {--public_background_font_Color: red;--publicHouver_background_Color: #fff;--header_background_Color: #fff;--menu_background: #fff; }3、引入less src/main.…