昇思MindSpore进阶教程-优化器

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

模型训练过程中,使用优化器更新网络参数,合适的优化器可以有效减少训练时间,提高模型性能。

最基本的优化器是随机梯度下降算法(SGD),很多优化器在SGD的基础上进行了改进,以实现目标函数能更快速更有效地收敛到全局最优点。MindSpore中的nn模块提供了常用的优化器,如nn.SGD、nn.Adam、nn.Momentum等。本章主要介绍如何配置MindSpore提供的优化器以及如何自定义优化器。

在这里插入图片描述

nn.optim

配置优化器

参数配置

在构建优化器实例时,需要通过优化器参数params配置模型网络中要训练和更新的权重。Parameter中包含了一个requires_grad的布尔型的类属性,用于表示模型中的网络参数是否需要进行更新。

网络中大部分参数的requires_grad默认值为True,少部分默认值为False,例如BatchNorm中的moving_mean和moving_variance。

MindSpore中的trainable_params方法会屏蔽掉Parameter中requires_grad为False的属性,在为优化器配置 params 入参时,可使用net.trainable_params()方法来指定需要优化和更新的网络参数。

import numpy as np
import mindspore
from mindspore import nn, ops
from mindspore import Tensor, Parameterclass Net(nn.Cell):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 6, 5, pad_mode="valid")self.param = Parameter(Tensor(np.array([1.0], np.float32)), 'param')def construct(self, x):x = self.conv(x)x = x * self.paramout = ops.matmul(x, x)return outnet = Net()# 配置优化器需要更新的参数
optim = nn.Adam(params=net.trainable_params())
print(net.trainable_params())

用户可以手动修改网络权重中 Parameter 的 requires_grad 属性的默认值,来决定哪些参数需要更新。

如下例所示,使用 net.get_parameters() 方法获取网络中所有参数,并手动修改巻积参数的 requires_grad 属性为False,训练过程中将只对非卷积参数进行更新。

conv_params = [param for param in net.get_parameters() if 'conv' in param.name]
for conv_param in conv_params:conv_param.requires_grad = False
print(net.trainable_params())
optim = nn.Adam(params=net.trainable_params())
学习率

学习率作为机器学习及深度学习中常见的超参,对目标函数能否收敛到局部最小值及何时收敛到最小值有重要影响。学习率过大容易导致目标函数波动较大,难以收敛到最优值,太小则会导致收敛过程耗时过长。除了设置固定学习率,MindSpore还支持设置动态学习率,这些方法在深度学习网络中能明显提升收敛效率。
固定学习率:
使用固定学习率时,优化器传入的learning_rate为浮点类型或标量Tensor。

以nn.Momentum为例,固定学习率为0.01,示例如下:

# 设置学习率为0.01
optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)

动态学习率:
mindspore.nn提供了动态学习率的模块,分为Dynamic LR函数和LearningRateSchedule类。其中Dynamic LR函数会预先生成长度为total_step的学习率列表,将列表传入优化器中使用,训练过程中,第i步使用第i个学习率的值作为当前step的学习率,其中total_step的设置值不能小于训练的总步数;LearningRateSchedule类将实例传递给优化器,优化器根据当前step计算得到当前的学习率。

运行中修改优化器参数

运行中修改学习率

mindspore.experimental.optim.Optimizer 中学习率为 Parameter,除通过上述动态学习率模块 mindspore.experimental.optim.lr_scheduler 动态修改学习率,也支持使用 assign 赋值的方式修改学习率。

例如下述样例,在训练step中,设置如果损失值相比上一个step变化小于0.1,将优化器第1个参数组的学习率调整至0.01:

net = Net()
loss_fn = nn.MAELoss()
optimizer = optim.Adam(net.trainable_params(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
last_step_loss = 0.1def forward_fn(data, label):logits = net(data)loss = loss_fn(logits, label)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)if ops.abs(loss - last_step_loss) < 0.1:ops.assign(optimizer.param_groups[1]["lr"], Tensor(0.01))return loss

运行中修改除lr以外的优化器参数

下述样例,在训练step中,设置如果损失值相比上一个step变化小于0.1,将优化器第1个参数组的 weight_decay 调整至0.02:

net = Net()
loss_fn = nn.MAELoss()
optimizer = optim.Adam(net.trainable_params(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
last_step_loss = 0.1def forward_fn(data, label):logits = net(data)loss = loss_fn(logits, label)return lossgrad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)if ops.abs(loss - last_step_loss) < 0.1:optimizer.param_groups[1]["weight_decay"] = 0.02return loss

自定义优化器

与上述自定义优化器方式相同,自定义优化器时也可以继承优化器基类experimental.optim.Optimizer,并重写__init__方法和construct方法以自行设定参数更新策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1549143.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

什么是期望最大化算法?

一、期望最大化算法 期望最大化&#xff08;EM&#xff09;算法是一种在统计学和机器学习中广泛使用的迭代方法&#xff0c;它特别适用于含有隐变量的概率模型参数估计问题。在统计学和机器学习中&#xff0c;有很多不同的模型&#xff0c;例如高斯混合模型&#xff08;GMM&…

LeetCode讲解篇之3. 无重复字符的最长子串

文章目录 题目描述题解思路代码实现 题目描述 题解思路 因为我们需要求无重复字符的最长子串&#xff0c;这个我们首先需要想到使用滑动窗口&#xff0c;窗口内记录无重复的子串的所有字符&#xff0c;移动窗口的右边界时&#xff0c;发现当前字符在窗口内已经出现&#xff0c…

unreal engine5制作动作类游戏时,我们使用刀剑等武器攻击怪物或敌方单位时,发现攻击特效、伤害等没有触发

UE5系列文章目录 文章目录 UE5系列文章目录前言一、问题分析二、使用步骤2.玩家角色碰撞设置3.怪物角色碰撞预设 最终效果 前言 在使用unreal engine5制作动作类游戏时&#xff0c;我们使用刀剑等武器攻击怪物或敌方单位时&#xff0c;发现攻击特效、伤害等没有触发。检查动画…

二叉树进阶oj题【二叉树相关10道oj题的解析和c++代码实现】

目录 二叉树进阶oj题1.根据二叉树创建字符串2.二叉树的层序遍历3.二叉树的层序遍历 II4.二叉树的最近公共祖先5.二叉搜索树和双向链表6.从前序与中序遍历序列构造二叉树7.从中序和后序遍历序列来构造二叉树8.二叉树的前序遍历&#xff0c;非递归迭代实现9.二叉树中序遍历 &…

烟雾污染云层检测系统源码分享

烟雾污染云层检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer…

ssh方式连接上GitHub(超详细的哦)

目录 先检查本地有无生成过SSH key 生成 SSH Key&#xff08;如果没有的话&#xff09; 查看新生成的 SSH 公钥 添加新的 SSH Key 测试 SSH 连接 先检查本地有无生成过SSH key ls -al ~/.ssh终端输入以上命令查看 如果有应该能看到以下列表 id_rsa id_rsa.pub id_ed2551…

端侧Agent系列 | 端侧AI Agent任务拆解大师如何助力AI手机?(详解版)

引言 简介 Octo-planner 规划和执行Agent框架 规划数据集 基准设计 实验设计 结果 全量微调与LoRA 多LoRA训练与合并 不同基础模型的全量微调 不同数据集大小的全量微调 总结 实战 英文 中文示例1&#xff1a; 中文示例2&#xff1a; 0. 引言 人生到处知何似…

PREDATOR: Registration of 3D Point Clouds with Low Overlap

Abstract 这篇文章介绍了一种新的点云配准模型-Predator。该模型专注于处理低重叠的点云对&#xff0c;它更加关注于重叠区域的处理&#xff0c;其新颖之处在于一个重叠的注意块&#xff0c;作用是用于两个点云的潜在编码之间的早期信息交换。该模型大大提高了低重叠场景下的配…

【博弈强化学习】——UAV-BS 的联合功率分配和 3D 部署:基于博弈论的深度强化学习方法

【论文】&#xff1a;Joint Power Allocation and 3D Deployment for UAV-BSs: A Game Theory Based Deep Reinforcement Learning Approach 【引用】&#xff1a;Fu S, Feng X, Sultana A, et al. Joint power allocation and 3D deployment for UAV-BSs: A game theory based…

C++深入学习string类成员函数(3):访问与修饰

引言 在 C 中&#xff0c;std::string 提供了丰富的成员函数来访问和修改字符串中的字符。通过这些函数&#xff0c;程序员可以灵活地处理字符串中的各个元素&#xff0c;无论是读取特定位置的字符&#xff0c;还是修改字符串的内容。此外&#xff0c;std::string 类还确保了访…

农牧场可视化管理:精准监测与优化运营

利用图扑可视化技术实现农牧场的实时数据监测和分析&#xff0c;优化资源配置&#xff0c;提高生产效率和可持续发展能力。

无需安装移动端的互传工具“快速分享”

本文首发于只抄博客&#xff0c;欢迎点击原文链接了解更多内容。 前言 前不久给大家介绍过 Windows 自带的 Nearby Sharing 附近分享&#xff0c;只需要在手机上安装个 App 就可以与 Windows 进行互传。而今天介绍的“快速分享”正好相反&#xff0c;是在 Windows 上安装 Goog…

tomcat安装与部署

一、基础准备 1. 节点规划 IP 主机名 节点 192.168.200.70 tomcat Tomcat 2. 环境准备 准备一台虚拟机&#xff0c;镜像为CentOS-7-x86_64&#xff0c;下载两个软件包&#xff0c;apache-tomcat-9.0.95.tar.gz&#xff1b;zrlog WAR包。 二、安装Tomcat 1.基础环境配…

【C++篇】从零实现 `list` 容器:细粒度剖析与代码实现

文章目录 从零实现 list 容器&#xff1a;细粒度剖析与代码实现前言1. list 的核心数据结构节点结构分析 2 迭代器设计与实现2.1 为什么 list 需要迭代器&#xff1f;2.2 实现一个简单的迭代器2.3 测试简单迭代器解释&#xff1a; 2.4 增加后向移动和 -> 运算符关键点&#…

多模态——基于XrayGLM的X光片诊断的多模态大模型

0.引言 近年来&#xff0c;通用领域的大型语言模型&#xff08;LLM&#xff09;&#xff0c;如ChatGPT&#xff0c;已在遵循指令和生成类似人类的响应方面取得了显著成就。这些成就不仅推动了多模态大模型研究的热潮&#xff0c;也催生了如MiniGPT-4、mPLUG-Owl、Multimodal-G…

Synchronized和 ReentrantLock有什么区别?

目录 一、java中的线程同步 二、Synchronized 使用方式 底层原理 synchronized 同步代码块的情况 synchronized 修饰方法的情况 总结 synchronized 和 volatile 有什么区别&#xff1f; 三、ReentrantLock 底层原理 使用方式 四、Synchronized和 ReentrantLock有什…

GPIO端口的使用

目录 一. 前言 二. APB2外设时钟使能寄存器 三. GPIO端口的描述 四. GPIO端口使用案例 一. 前言 基于库函数的开发方式就是使用ST官方提供的封装好的函数。而如果没有添加库函数&#xff0c;那就是基于寄存器的开发方式&#xff0c;这种方式一般不是很推荐。因为由于ST对寄存…

docker pull 超时的问题如何解决

docker不能使用&#xff0c;使用之前的阿里云镜像失败。。。 搜了各种解决方法&#xff0c;感谢B站UP主 <iframe src"//player.bilibili.com/player.html?isOutsidetrue&aid113173361331402&bvidBV1KstBeEEQR&cid25942297878&p1" scrolling"…

维护左边枚举右边

前言&#xff1a;一开始遇到这个题目的时候没啥思路&#xff0c;但是当我看到值域在1000的时候我想着直接暴力从右边枚举不就行了吗&#xff0c;时间复杂度刚刚好&#xff0c;试一下就过了 正解应该是啥呢&#xff0c;其实也是维护一遍&#xff0c;运行另外一边 O ( n ) O(n)…

所有测试人,下半年的新方向(大模型),赢麻了!!!

现在做测试&#xff0c;真的挺累的。 现在测试越来越难做&#xff0c;晋升困难&#xff0c;工资迟迟不涨……公司裁员&#xff0c;测试首当其冲&#xff01;&#xff01; 做测试几年了&#xff0c;还没升职&#xff0c;就先到了“职业天花板”。 想凭工作几年积累的经验&…