昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练

文章目录

  • 昇思25天学习打卡营第08天 | 模型训练
    • 超参数
    • 损失函数
    • 优化器
      • 优化过程
    • 训练与评估
    • 总结
    • 打卡

模型训练一般遵循四个步骤:

  1. 构建数据集
  2. 定义神经网络模型
  3. 定义超参数、损失函数和优化器
  4. 输入数据集进行训练和评估

构建数据集和网络模型在之前的内容在已经涉及,不再赘述。

超参数

超参数(Hyperparameters)是可以调整的参数,可以控制模型训练的过程。

深度学习模型多采用随机梯度下降算法SGD进行优化:
w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_t- \eta\frac1n\sum_{x\in B}\nabla l(x,w_t) wt+1=wtηn1xBl(x,wt)
其中, η \eta η是学习率, n n n是batch大小,都是超参数,这两个参数是直接影响模型性能收敛的重要参数。
一般会定义三个超参数:

  • epoch:遍历数据集的次数
  • batch size:每个批次数据的大小。size 过小导致花费时间多,梯度震荡严重,不利于收敛;size 过大容易陷入局部极小值。
  • learning rate:学习率国小会导致收敛速度变慢;过大则可能会导致训练不收敛。

损失函数

损失函数用于评估模型预测值和目标值之间的误差。
常见的损失函数包括:

  • nn.MSELoss:均方误差,用于回归
  • nn.NLLLoss:负对数似然,用于分类
  • nn.CrossEntropyLoss:结合了nn.LogSoftmaxnn.NLLLoss,可以对logits进行归一化并计算预测误差
loss_fn = nn.CrossEntropyLoss()

优化器

优化器内部定义了模型参数的优化过程,所有的优化逻辑都封装在优化器对象中。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

优化过程

通过自动微分获得的微分函数,计算参数对应的梯度,并传入优化器中,即可实现参数优化。

grads = grad_fn(inputs)
optimizer(grads)

训练与评估

遍历一次数据集被称为一轮(epoch),每轮执行训练时包含两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,检查模型性能是否提升。
# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程一般为:

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)

总结

这一节的内容对深度学习模型训练的一般过程进行了详细的介绍,从数据集构建到模型定义,接着定义超参数并选择合适的值,创建损失函数和优化器对象完成训练前的准备。通过封装一个模型调用和loss计算的前向计算函数并自动微分,在每个epoch中计算loss并优化参数,从而完成模型的训练。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473907.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【致知功夫 各随分限】成长需要时间,助人须考虑对方的承受程度

帮助他人需考虑各人的分限所能及的,初学圣学需时间沉淀,存养心性 任何人都应该受到教育,不应受到贫富、贵贱的差异而排除在教育之外,对于不同材质的学生,需要因材施教; 每天都有新的认知,大我…

基于工业互联网的智慧矿山解决方案PPT(38页)

文章摘要 工业互联网与智慧矿山 基于工业互联网的新一代智慧矿山解决方案,将互联网和新一代IT技术与工业系统深度融合,形成关键的产业和应用生态,推动工业智能化发展。该方案以“四级、三层、两网、一平台”为总体框架,强调应用目…

[Vite]Vite插件生命周期了解

[Vite]Vite插件生命周期了解 Chunk和Bundle的概念 Chunk: 在 Vite 中,chunk 通常指的是应用程序中的一个代码片段,它是通过 Rollup 或其他打包工具在构建过程中生成的。每个 chunk 通常包含应用程序的一部分逻辑,可能是一个路由视…

2024菜鸟春招笔试

第一题 解题思路: 签到题,把帖子按好评度降序排列,再将人按升序排列。 第二题 解题思路 从左到右遍历,如果当前元素没有错排,将其与后一个交换,这样两个元素一定都错排。 第三题 、 解题思路 这题当时暴力…

智能运维场景探索 | 运营分析

【本场景来源于 擎创科技《一体化数智运维AIOps解决方案》白皮书,经过重新编写】 该场景主要围绕生产运行、运营决策两个维度进行展开,通过对配置、性能、业务等运行数据的加工计算,形成可量化运营效果、可衡量发展方向的运营数据。整体以低…

陈志泊主编《数据库原理及应用教程第4版微课版》的实验题目参考答案实验2

实验目的 1.掌握在SQL Server中使用对象资源管理器和SQL命令创建数据库与修改数据库的方法。 2.掌握在SQL Server中使用对象资源管理器或者SQL命令创建数据表和修改数据表的方 法(以SQL命令为重点)。 实验设备 操作系统:Win11…

CV03_mAP计算以及COCO评价标准

COCO数据集回顾:CV02_超强数据集:MSCOCO数据集的简单介绍-CSDN博客 1.1 简介 在目标检测领域中,mAP(mean Average Precision,平均精度均值)是一个广泛使用的性能评估指标,用于衡量目标检测模型…

MongoDB集群搭建-最简单

目录 前言 一、分片概念 二、搭建集群的步骤 总结 前言 MongoDB分片(Sharding)是一种水平扩展数据库的方法,它允许将数据分散存储在多个服务器上,从而提高数据库的存储容量和处理能力。分片是MongoDB为了应对大数据量和高吞吐量需…

创新引领未来,智慧水利在路上:数字孪生技术为水库管理开辟新机遇,带来新挑战,引领水利行业迈向智能化新纪元

目录 前言 一、数字孪生技术概述 二、新机遇:数字孪生技术如何重塑水库管理 1、精准预测,科学调度 2、智能监测,及时预警 3、优化资源配置,提升管理效率 4、促进公众参与,增强透明度 三、新挑战:数字…

Fill - UVA 10603

网址如下&#xff1a; Fill - UVA 10603 - Virtual Judge (vjudge.net) 感觉有点浮躁&#xff0c;没法完全将思绪投入题的思考中 脑袋糊糊的 一道bfs题 代码如下&#xff1a; #include<queue> #include<cstdio> #include<cstring> #include<vector&g…

奇迹MU 骷髅战士在哪

BOSS分布图介绍 我为大家带来各地区怪物分布图。在游戏前期&#xff0c;很多玩家可能会不知道该去哪里寻找怪物&#xff0c;也不知道哪些怪物值得打。如果选择了太强的怪物&#xff0c;弱小的玩家可能会无法抵御攻击。如果选择了低等级的boss&#xff0c;收益可能并不理想。所…

吴恩达机器学习 第三课 week3 强化学习(月球着陆器自动着陆)

目录 01 学习目标 02 概念 2.1 强化学习 2.2 深度Q学习&#xff08;Deep Q-Learning &#xff09; 03 问题描述 04 算法中的概念及原理 05 月球着陆器自动着陆的算法实现 06 拓展&#xff1a;基于pytorch实现月球着陆器着陆 07 总结 写在最前&#xff1a;关于强化学习…

【MindSpore学习打卡】应用实践-自然语言处理-基于RNN的情感分类:使用MindSpore实现IMDB影评分类

情感分类是自然语言处理&#xff08;NLP&#xff09;中的一个经典任务&#xff0c;广泛应用于社交媒体分析、市场调研和客户反馈等领域。本篇博客将带领大家使用MindSpore框架&#xff0c;基于RNN&#xff08;循环神经网络&#xff09;实现一个情感分类模型。我们将详细介绍数据…

UE5 07-给物体添加一个拖尾粒子

添加一个(旧版粒子系统)cascade粒子系统组件 ,在模板中选择一个开发学习初始包里的粒子

智慧文旅(景区)解决方案PPT(42页)

智慧文旅解决方案摘要 行业分析中国旅游业正经历消费大众化、需求品质化、发展全域化和产业现代化的发展趋势。《“十三五”旅游业发展规划》的发布&#xff0c;以及文化和旅游部的设立&#xff0c;标志着旅游业的信息化和智能化建设成为国家战略。2018年推出的旅游行业安全防范…

cs224n作业4

NMT结构图&#xff1a;&#xff08;具体结构图&#xff09; LSTM基础知识 nmt_model.py&#xff1a; 参考文章&#xff1a;LSTM输出结构描述 #!/usr/bin/env python3 # -*- coding: utf-8 -*-""" CS224N 2020-21: Homework 4 nmt_model.py: NMT Model Penchen…

2024年导游资格证题库备考题库,高效备考!

1.台湾著名的太鲁阁公园的特色是&#xff08;&#xff09;。 A.丘陵和溶洞 B.森林和瀑布 C.峡谷和断崖 D.彩林和彩池 答案&#xff1a;C 解析&#xff1a;台湾著名的太鲁阁公园的特色是峡谷和断崖。 2.下列位于台湾的景区中&#xff0c;素有"神秘的森林王国"之…

51单片机STC89C52RC——15.1 AD/DA(模数数模)

目的/效果 1 LCD1602 显示 可调电阻、光敏电阻、热敏电阻值&#xff08;AD&#xff09; 2 模拟信号控制LED明暗&#xff08;DA&#xff09; 一&#xff0c;STC单片机模块 二&#xff0c;AD/DA 2.1 AD/DA 介绍 AD&#xff08;Analog to Digital&#xff09;&#xff1a;模拟…

金丝键合强度测试仪试验条件要求:键合拉脱/引线拉力/剪切力等

金丝键合强度测试仪是测量引线键合强度&#xff0c;评估键合强度分布或测定键合强度是否符合有关的订购文件的要求。键合强度试验机可应用于采用低温焊、热压焊、超声焊或有关技术键合的、具有内引线的器件封装内部的引线-芯片键合、引线-基板键合或内引线一封装引线键合&#…

Redis的zset的zrem命令可以做到O(1)吗?

事情是这样的&#xff0c;当我用zrem命令去移除value的时候&#xff0c;我知道他之前会做的几个步骤 1、查找这个value对应的score&#xff08;通过zset中的dict&#xff09;2、根据这个score查找到跳表中的节点3、删除这个节点 我就想了一下为什么dict为什么要保存score呢&a…