优化器与现有网络模型的修改

文章目录

    • 一、优化器是什么
    • 二、优化器的使用
    • 三、分类模型VGG16
    • 四、现有网络模型的修改

一、优化器是什么

优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)。损失函数衡量的是模型预测值与真实值之间的差异,而优化器则负责通过更新模型的权重(Weights)和偏置(Biases)来减少这种差异。

利用得到的梯度,用优化器对梯度进行修正,从而得到整体误差降低的目的。

优化器Optimizer 所需要从参数:

在这里插入图片描述

参数解析:

  • model.parameters()是训练的模型
  • lr(LearningRate)是学习率,这是最核心的参数之一,它决定了在每次迭代中参数更新的步长。如果学习率太高,可能会导致训练过程中的梯度爆炸,使模型无法收敛,训练很不稳定;如果学习率太低,训练过程可能会变得非常缓慢。
    推荐一开始用大的lr值进行运算,到后面用小的lr再进行运算。
  • 动量(Momentum)往往是特定参数,是用于加速梯度下降方法,特别是在处理凸优化问题时。它通过在连续的迭代中累积梯度信息来帮助优化器克服局部最小值,并加快收敛速度。

二、优化器的使用

本文使用我的上一章内容神经网络内容进行续写,神经网络具体可跳转损失函数和反向传播

使用一下代码来进行梯度优化:

    optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()

整体代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()

在未运行时的梯度没有值:
在这里插入图片描述
当运行一下:
在这里插入图片描述
可以看到每个参数节点的值被计算出来了。

当for循环第二次运行的时候,可以看到grad梯度已经被优化了:

在这里插入图片描述

通过反复循环,上图中的data数据,也就是loss就会越来越被优化。

上面的for循环其实是为数据的一次小循环,我们可以加上epoch 外嵌套 进行数据的一轮轮循环深度优化:

for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_loss

整体代码:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)#这里是进行一轮一轮的学习
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_lossprint(running_loss)

运行结果如下,可以看到,整个神经网络在所有的数据当中,它的误差之和如下:

在这里插入图片描述

在第一轮优化的时候,整个神经网络的误差之和是18779
在第二轮优化的时候,整个神经网络的误差之和是16205
在第三轮优化的时候,整个神经网络的误差之和是15448

可以看到,通过优化器的一轮轮优化,整体的loss值会一直降低,从而达到数据优化的效果。

三、分类模型VGG16

pytorch为我们提供了很多网络模型,其中包括分类模型VGG16

分类模型VGG16是基于ImageNet数据集进行训练的,所以我们需要下载ImageNet数据集

由于ImageNet数据集的内存为143g,会发生以下报错,需要我们自己去下载ImageNet数据集再放在根目录当中。
在这里插入图片描述

既然ImageNet数据集太大,那么就换一条思路,用一下方法加载vgg16

import torchvision.datasets
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print('ok')

如果pretrained = True,说明这个数据集已经是训练好的了。
如果pretrained = False,说明这些参数是一个初始参数,没有在任何参数集上面进行训练。
如果progress = True,显示下载进度条
如果progress = Flase,则不显示下载进度条

vgg16_false = torchvision.models.vgg16(pretrained=False),这代码表示只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的),所以它不需要下载。
vgg16_True = torchvision.models.vgg16(pretrained=True),这代码表示需要把网络模型参数进行一个下载,还要加载对应的参数。故它需要进行下载。
简单理解就是False不需要进行下载,而True需要进行下载。
VGG16将数据集分成1000个类。

print(vgg16_true)
输出结果:
在这里插入图片描述
在这里插入图片描述
看它把各种卷积层,最大池化都自动按参数下载好了。

常用的CIFAR10会把数据集分成10个类。
vgg16会把数据集分成1000个类,如上图的out_features=1000

四、现有网络模型的修改

方法:像上面得到的是out_features=1000,我们可以进行一个新的处理,通过Linear将输入是1000,而输出为10,从而达到降类的效果。

vgg16_true.add_module("add_linear", nn.Linear(1000, 10))

运行得到:
在这里插入图片描述
可以看到,在add_linear这里的out_features=10

如果要想类的改变在classifier当中,那么代码只需要添加上classifier

vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

运行结果:
在这里插入图片描述
整体代码如下:

import torchvision.datasets
from torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)train_data = torchvision.datasets.CIFAR10("./data",train=True, transform=torchvision.transforms.ToTensor(),download=True)vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

如果想直接在上面 (6)Linear 里面修改out_features,而不是新命名一个(add_linear)进行修改也是可以的

用vgg16_flase进行示范:

在没进行修改前print(vgg16_false)

运行结果:
在这里插入图片描述
直接在(6)Linear中修改out_features为10

代码:

vgg16_false.classifier[6] = nn.Linear(4096, 10)

运行结果:
在这里插入图片描述
可以看到out_features=10,从而成功修改现有的网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1539143.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【论文阅读笔记】YOLOv10: Real-Time End-to-End Object Detection

论文地址:https://arxiv.org/abs/2405.14458 文章目录 论文小结论文简介论文方法为NMS-free训练的一致性双标签分配双标签分配一致性匹配度量 效率-精度整体驱动的模型设计效率驱动模型设计轻量级分类检测头Spatial-channel 解耦下采样Rank-guided block design 精度…

linux 操作系统下的dhclient命令介绍和案例使用

linux 操作系统下的dhclient命令介绍和案例使用 dhclient 是 Linux 系统中用于动态主机配置协议(DHCP)客户端的命令。它的主要功能是从 DHCP 服务器获取网络配置,包括 IP 地址、子网掩码、默认网关和 DNS 服务器等信息 dhclient 命令概述 …

transformer共享权重对联模型

嵌入维度512,8头,1层 |分割中最从左到右依次是数据集上联,模型预测下联,数据集下联 ,有些对联对的还是可以的 嵌入维度512,8头,3层,最后一个输出层采用线性层,模型训练过程 上面是模型训练过程,下面是模型训练结果 从左到右,上联,模型生成,下…

满足10人同时绘图的图形工作站

在当今这个数字化与创意并重的时代,图形工作站作为设计师、艺术家及数字内容 创作者们的重要工具,其性能与效率直接关系到项目的成功与否。 当谈及满足10人同时绘图的图形工作站时,我们不仅要考虑硬件的峰值性能,还需兼顾软件的兼…

PSINS,GNSS速度与SINS滤波的MATLAB代码

文章目录 程序说明主要特点适用范围获取方式运行截图 程序说明 基于PSINS工具箱的GNSS和SINS滤波的MATLAB代码,观测量为GNSS的三轴速度。 专为工程师和研究人员设计,助您轻松实现高精度的导航和定位。 主要特点 高精度滤波算法:结合PSINS和…

中间件:maxwell、canal

文章目录 1、底层原理:基于mysql的bin log日志实现的:把自己伪装成slave2、bin log 日志有三种模式:2.1、statement模式:2.2、row模式:2.3、mixed模式: 3、maxwell只支持 row 模式:4、maxwell介…

思通数科开源智能文档识别平台的核心功能

思通数科的智能文档识别平台是一个综合性的解决方案,旨在通过人工智能技术提升文档识别处理的效率和准确性。 主要的功能是: 1. 信息抽取与数据结构化 票据识别与抽取:利用OCR技术自动识别和提取票据上的关键信息,如日期、金额等…

几何 | 数学专项

日期内容2024.9.19创建 { d > 0 , 递增数列 d < 0 , 递减数列 d 0 &#xff0c;常数列 \begin{cases} d>0,递增数列\\ d<0,递减数列\\ d0&#xff0c;常数列 \end{cases} ⎩ ⎨ ⎧​d>0,递增数列d<0,递减数列d0&#xff0c;常数列​ 【2010.13】 【1.历年真…

三菱变频器以模拟方式进行频率设定:(电压输入)

POINT 在STF(STR)信号 ON时&#xff0c;发出启动指令。 通过电位器(频率设定器)下达频率指令。(端子2-5之间连接(电压输入)) 接线示例 (变频器为频率设定器提供5V 电源。(端子 10)) 操作示例 以 60Hz 运行 NOTE. 1、正转开关(STF)与反转开关(STR)同时为0N时无法启动。此外&a…

利用Leaflet.js和turf.js创建交互式地图:航道路线绘制

引言 在现代Web应用中&#xff0c;地图的交互性是提供丰富用户体验的关键。Leaflet.js是一个轻量级的开源JavaScript库&#xff0c;它提供了简单易用的API来构建交云的地图。与此同时&#xff0c;turf.js作为一个强大的地理空间分析库&#xff0c;能够处理复杂的地理数据操作。…

SourceTree保姆级教程1:(克隆,提交,推送)

本人认为sourceTree 是最好用的版本管理工具&#xff0c;下面将讲解下sourceTree 客户端工具 克隆&#xff0c;提交&#xff0c;推送 具体使用过程&#xff0c;废话不多说直接上图。 使用步骤&#xff1a; 首先必须要先安装Git和sourceTree&#xff0c;如何按照参考其它文章&…

C++第四讲:模板初阶

C第四讲&#xff1a;模板初阶 1.泛型编程2.函数模板2.1什么是函数模板2.2函数模板的使用2.3函数模板的原理2.4函数模板的实例化2.4.1隐式实例化2.4.2显式实例化 2.5模板参数的匹配原则2.5.1原则12.5.2原则22.5.3原则3 3.类模板3.1类模板的定义格式3.2类模板的实例化 1.泛型编程…

GPT代码记录

#include <iostream>// 基类模板 template<typename T> class Base { public:void func() {std::cout << "Base function" << std::endl;} };// 特化的子类 template<typename T> class Derived : public Base<T> { public:void…

在线安全干货|如何更改IP地址?

更改IP地址是一个常见的需求&#xff0c;无论是为了保护个人隐私、绕过地理限制还是进行商业数据分析。不同的IP更改方法适用于不同的需求和环境。但请注意&#xff0c;更改IP地址应在合法场景下进行&#xff0c;无论使用什么方法&#xff0c;都需要在符合当地网络安全法律法规…

寻呼机爆炸,炸醒通讯安全警惕心

据央视新闻报道&#xff1a;当地时间17日下午&#xff0c;黎巴嫩首都贝鲁特以及黎巴嫩东南部和东北部多地发生寻呼机爆炸事件。黎巴嫩公共卫生部长阿卜亚德称&#xff0c;爆炸已造成9人死亡&#xff0c;约有2800人受伤&#xff0c;其中约200人伤情危重。 来源&#xff1a;央视新…

20240919在友善之臂的NanoPC-T6开发板上适配宸芯的数传模块CX6602N

20240919在友善之臂的NanoPC-T6开发板上适配宸芯的数传模块CX6602N 2024/9/19 16:54 缘起&#xff0c;大毛PK二毛战况激烈&#xff0c;穿越机大卖&#xff01;我司拆同行的图传作品。 发现&#xff1a; 主控&#xff1a;飞凌OK3588-C核心板 图传模块&#xff1a;宸芯的数传模块…

不断挑战才有不断机遇!Eagle Trader等你来加入!

2024“Eagle Trader杯”全国职业交易联赛S1赛季已火热进行一个多月&#xff0c;吸引了超过355名交易员的积极参与&#xff01;目前&#xff0c;每天都有新的交易员踊跃报名参加&#xff01; 经过严格地交易考核&#xff0c;13名选手成功通过初试&#xff0c;正进入下一阶段的挑…

【C++初阶】vector模拟实现

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ &#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1…

方法:批量提取PPT幻灯片中图片

处理包含大量图片的PPT&#xff08;PowerPoint&#xff09;幻灯片已成为许多专业人士的日常任务之一。然而&#xff0c;手动从每张幻灯片中逐一提取图片不仅耗时耗力&#xff0c;还容易出错。为了提升工作效率&#xff0c;减少重复劳动&#xff0c;探索并实现一种高效批量提取P…

WebGL系列教程十(模型Model、视图View、投影Projection变换)

目录 1 前言2 模型变换3 视图变换3.1 公式推导3.1.1 确定摄像机的参数3.1.2 构建摄像机坐标系3.1.3 构建视图变换矩阵3.1.4 组合视图矩阵 3.2 方法调用 4 投影变换4.1 正交投影推导4.2 正交投影调用4.3 透视投影推导4.3 透视投影调用 5 总结 1 前言 上一讲我们讲了动画&#xf…