昇思MindSpore学习入门-回调机制

在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。

Callback回调机制一般用在网络模型训练过程Model.train中,MindSpore的Model会按照Callback列表callbacks顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。

Callback介绍

当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:

假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。

Callback是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作

例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。下面我们继续以手写体识别模型为例,介绍常见的内置回调函数和自定义回调函数。

常用的内置回调函数

MindSpore提供Callback能力,支持用户在训练/推理的特定阶段,插入自定义的操作。

ModelCheckpoint

用于保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了ModelCheckpoint接口,一般与配置保存信息接口CheckpointConfig配合使用。

LossMonitor

用于监控训练或测试过程中的损失函数值Loss变化情况,可设置per_print_times控制打印Loss值的间隔。

训练场景下,LossMonitor监控训练的Loss值;边训练边推理场景下,监控训练的Loss值和推理的Metrics值

TimeMonitor

用于监控训练或测试过程的执行时间。可设置data_size控制打印执行时间的间隔。

自定义回调机制

MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于Callback基类自定义回调类。

用户可以基于Callback基类,根据自身的需求,实现自定义Callback。Callback基类定义如下所示:

class Callback():

    """Callback base class"""

    def on_train_begin(self, run_context):

        """Called once before the network executing."""

    def on_train_epoch_begin(self, run_context):

        """Called before each epoch beginning."""

    def on_train_epoch_end(self, run_context):

        """Called after each epoch finished."""

    def on_train_step_begin(self, run_context):

        """Called before each step beginning."""

    def on_train_step_end(self, run_context):

        """Called after each step finished."""

    def on_train_end(self, run_context):

        """Called once after network training."""

回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量RunContext.original_args(),传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给RunContext.original_args()对象。

RunContext.original_args()中的常用属性有:

  • epoch_num:训练的epoch的数量
  • batch_num:一个epoch中step的数量
  • cur_epoch_num:当前的epoch数
  • cur_step_num:当前的step数
  • loss_fn:损失函数
  • optimizer:优化器
  • train_network:训练的网络
  • train_dataset:训练的数据集
  • net_outputs:网络的输出结果
  • parallel_mode:并行模式
  • list_callback:所有的Callback函数

通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。

自定义终止训练

实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。

下面代码中,通过run_context.original_args方法可以获取到cb_params字典,字典里会包含前文描述的主要属性信息。

同时可以对字典内的值进行修改和添加,在begin函数中定义一个init_time对象传递给cb_params字典。每个数据迭代结束step_end之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。

从上面的打印结果可以看出,当第3个epoch的第4673个step执行完时,运行时间到达了阈值并结束了训练。

自定义阈值保存模型

该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。

示例代码如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1472585.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

sublime如何运行Html文件?

背景: 在sublime上面写了html代码以后,怎么运行html文件来进行debug呢?如果去点击保存的HTML文件,每次这样就会很麻烦,能不能直接在sublime里面点什么就可以直接打开浏览器运行呢?答案是OK的。 1-确认Vie…

通过营销本地化解锁全球市场

在一个日益互联的世界里,企业必须接触到全球各地的不同受众。营销本地化是打开这些全球市场的关键。它包括调整营销材料,使其与不同地区的文化和语言细微差别产生共鸣。以下是有效的营销本地化如何推动您的全球扩张,并用实际例子来说明每一点…

长难句打卡6.27

After about 40 minutes of shopping, most people stop struggling to be rationally selective, and instead begin shopping emotionally—which is the point at which we accumulate the 50 percent of stuff in our cart that we never intended buying. 购物约40分钟后&…

【C++】 解决 C++ 语言报错:Invalid Use of Incomplete Type

文章目录 引言 在 C 编程中,“Invalid Use of Incomplete Type” 是一种常见错误。此错误通常在程序试图使用未完全定义的类或结构时发生。这种错误不仅会导致编译失败,还可能导致程序行为不可预测。本文将详细探讨无效使用不完整类型的成因、检测方法及…

半导体切割研磨废水的处理技术

半导体切割研磨废水处理是一个复杂而关键的过程,其废水主要来源于切割、研磨等工艺环节,这些过程中使用的化学品、冷却水、洗涤水等会产生含有重金属、有机物、酸碱度不稳定以及高浓度硅化合物等污染物的废水。针对这些废水的特性,半导体行业…

UG NX二次开发(C#)-根据草图创建拉伸特征(UFun+NXOpen)

文章目录 1、前言2、在UG NX中创建草图,然后创建拉伸特征3、基于UFun函数的实现4、基于NXOpen的实现代码1、前言 UG NX是基于特征的三维建模软件,其中拉伸特征是一个很重要的特征,有读者问如何根据草图创建拉伸特征,我在这篇博客中讲述一下草图创建拉伸特征的UG NX二次开发…

“免费”的可视化大屏案例分享-智慧园区综合管理平台

一.智慧园区是什么? 智慧园区是一种融合了新一代信息与通信技术的先进园区发展理念。它通过迅捷信息采集、高速信息传输、高度集中计算、智能事务处理和无所不在的服务提供能力,实现了园区内及时、互动、整合的信息感知、传递和处理。这样的园区旨在提高…

vue2+ant(上传+下载)

下载(导出) 第一步——封装的axios 导出必须加responseType: blob 它是一个常用于 XMLHttpRequest 或 fetch API 的选项,它指定了响应的类型。当设置为 ‘blob’ 时,这意味着预期服务器返回的是一个二进制大对象(Blob…

多模态融合算法应用:CT + 临床文本数据 + pyradiomics提取到的图像特征

多模态融合算法应用 CT 临床文本数据 pyradiomics提取图像特征 单模态建模临床数据建模pyradiomics提取图像特征建模CT建模 多模态建模前融合为什么能直接合并在一起? 后融合Med-CLIP:深度学习 可解释性 单模态建模 临床数据建模 临床文本数据&…

【讨论C++多态】

讨论C多态 多态概念定义及实现 虚函数虚函数重写 final和override重载、重写和重定义抽象类纯虚函数接口继承和实现继承 多态的原理虚函数表打印单继承虚函数表动态绑定和静态绑定多继承虚函数表 多态 概念 多态即完成某个行为,不同对象会产生不同的状态。 定义及实…

基于Qt实现的PDF阅读、编辑工具

记录一下实现pdf工具功能 语言:c、qt IDE:vs2017 环境:win10 一、功能演示: 二、功能介绍: 1.基于saribbon主体界面框架,该框架主要是为了实现类似word导航项 2.加载PDF放大缩小以及预览功能 3.pdf页面跳转…

【MySQL】事务实现原理

目录 事务 如何使用 ACID 原子性(Atomicity) 原子性实现原理 持久性(Durability) 持久性实现原理 隔离性 隔离级别 读未提交 读已提交 可重复读 串行化 隔离级别原理 锁 共享锁&独占锁 意向锁 索引记录锁 间隙锁 临键锁 插入意向锁 自增锁 MVCC 实现…

Node.js 核心知识点 - Koa 框架

一、Koa 基本概念 官网:Koa - next generation web framework for node.js 1、Koa 是什么? Koa 是一个基于 Node.js 的轻量级 web 框架,由 Express 团队创造。Koa 的设计理念是使用现代的 JavaScript 特性(如 async/await&#x…

【Spring cloud】 认识微服务

文章目录 🍃前言🌴单体架构🎋集群和分布式架构🌲微服务架构🎍微服务带来的挑战⭕总结 🍃前言 本篇文章将从架构的演变过程来简单介绍一下微服务,大致分为一下几个部分 单体架构集群和分布式架…

九芯电子手把手教你选购电动车防盗语音报警器芯片

电动车,也叫电瓶车,加装的防盗器声音非常大,能使电动车防盗报警器变得更智能化,功能多样化。本文将介绍在选购电动车防盗语音报警芯片,应该考虑哪些因素,以确保所选产品既满足安全需求,又具备物…

深度学习-数学基础(四)

深度学习数学基础 数学基础线性代数-标量和向量线性代数-向量运算向量加和向量内积向量夹角余弦值 线性代数-矩阵矩阵加法矩阵乘法矩阵点乘矩阵计算的其他内容 人工智能-矩阵的操作矩阵转置(transpose)矩阵与向量的转化 线性代数-张量(tensor…

CFS三层内网渗透——第二层内网打点并拿下第三层内网(三)

目录 八哥cms的后台历史漏洞 配置socks代理 ​以我的kali为例,手动添加 socks配置好了,直接sqlmap跑 ​登录进后台 蚁剑配置socks代理 ​ 测试连接 ​编辑 成功上线 上传正向后门 生成正向后门 上传后门 ​内网信息收集 ​进入目标二内网机器&#xf…

phpcms 升级php8.3.8

windows 2008 server 不支持php8.3.8,需升级为windows 2012 1.下载php8.3.8 PHP8.3.9 For Windows: Binaries and sources Releases 2.配置php.ini (1.)在php目录下找到php.ini-development文件,把它复制一份,改名为php.ini (2.)修改php安装目录 根…

昇思MindSpore学习笔记4-02生成式--DCGAN生成漫画头像

摘要: 记录了昇思MindSpore AI框架使用70171张动漫头像图片训练一个DCGAN神经网络生成式对抗网络,并用来生成漫画头像的过程、步骤。包括环境准备、下载数据集、加载数据和预处理、构造网络、模型训练等。 一、概念 深度卷积对抗生成网络DCGAN Deep C…

Runway Gen-3 实测,这就是 AI 视频生成的 No.1!视频高清化EvTexture 安装配置使用!

Runway Gen-3 实测,这就是 AI 视频生成的 No.1!视频高清化EvTexture 安装配置使用! 由于 Runway 作为一个具体的工具或平台,其详细信息在搜索结果中没有提供,我将基于假设 Runway 是一个支持人工智能和机器学习模型的创意工具,提供一个关于使用技巧和类似开源项目的文稿总…