模型训练中出现loss为NaN怎么办?

在这里插入图片描述

文章目录

  • 一、模型训练中出现loss为NaN原因
    • 1. 学习率过高
    • 2. 梯度消失或爆炸
    • 3. 数据不平衡或异常
    • 4. 模型不稳定
    • 5. 过拟合
  • 二、 针对梯度消失或爆炸的解决方案
    • 1. 使用`torch.autograd.detect_anomaly()`
    • 2. 使用 torchviz 可视化计算图
    • 3. 检查梯度的数值范围
    • 4. 调整梯度剪裁
  • 三、更具体的办法
    • 3.1 可能导致梯度爆炸的部分
    • 3.2 解决方案

一、模型训练中出现loss为NaN原因

1. 学习率过高

在训练的某个阶段,学习率可能设置得过高,导致模型参数更新幅度过大,甚至可能出现数值不稳定的情况。你可以尝试降低学习率,并观察训练过程中的变化。

2. 梯度消失或爆炸

如果模型的某些层出现梯度消失或爆炸的问题,可能会导致loss变得异常低。你可以检查梯度的大小,确保它们在合理范围内。

3. 数据不平衡或异常

训练数据中可能存在异常值或分布不平衡的情况,导致模型在某些批次的训练过程中出现异常。你可以检查数据集,确保数据质量。

4. 模型不稳定

模型架构或训练过程中的某些设置可能导致不稳定,比如过深的网络、过复杂的模型等。你可以尝试简化模型架构或添加正则化项。

5. 过拟合

模型可能在某些阶段已经过拟合到训练数据上,导致训练loss异常低而验证loss较高。你可以通过早停法(early stopping)、正则化、数据增强等方法来缓解过拟合问题。
解决方法

  1. 调节学习率:适当降低学习率,观察训练过程中的变化。
  2. 检查梯度:通过torch.autograd检查梯度的大小,确保没有出现梯度消失或爆炸。
  3. 数据检查:确保数据集没有异常值或分布不平衡的情况。
  4. 模型架构:简化模型架构,增加正则化项,如L2正则化、dropout等。
  5. 验证集监控:通过监控验证集的loss和指标,防止过拟合。\

二、 针对梯度消失或爆炸的解决方案

使用 torch.autograd.detect_anomaly() 和相关工具确实可以帮助你检测和排除训练过程中出现的梯度问题。以下是如何在你的代码中使用这些工具来检测异常和可视化梯度的示例。

1. 使用torch.autograd.detect_anomaly()

这个函数可以帮助检测反向传播过程中出现的异常,并输出具体的错误信息和位置。

import torch# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()

2. 使用 torchviz 可视化计算图

torchviz 是一个可以帮助你可视化计算图的工具,这对于调试复杂的模型非常有用。

首先,安装 torchviz:

pip install torchviz

然后,可以使用以下代码来生成和保存计算图:

from torchviz import make_dot# 定义模型
model = MyModel()# 输入数据
inputs = torch.randn(56, 1024, 28, 28)# 获取模型输出
outputs = model(inputs)# 创建计算图
dot = make_dot(outputs, params=dict(model.named_parameters()))# 保存计算图
dot.format = 'png'
dot.render('model_graph')

3. 检查梯度的数值范围

你可以在每个训练步骤之后检查模型中各个参数的梯度,以确保梯度的数值范围正常。

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 检查梯度数值范围for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')optimizer.step()

4. 调整梯度剪裁

在训练过程中,可以使用梯度剪裁来防止梯度爆炸。以下是如何在 PyTorch 中实现梯度剪裁的示例:

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 梯度剪裁torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()

通过以上方法,可以更好地检测和调试训练过程中出现的梯度问题,提高模型的训练稳定性和效率。如果在使用过程中发现任何异常或需要进一步调试,请随时提供更多细节。

三、更具体的办法

3.1 可能导致梯度爆炸的部分

  1. ReLU 激活函数的使用:激活函数可参考激活函数汇总
    ReLU 是一种常见的激活函数,但如果输入有较大的正值,经过 ReLU 之后,这些值会直接传递下去,可能导致后续层的梯度爆炸。考虑使用其他激活函数,如 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

    embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
    
  2. 特征插值:
    插值操作可能会生成较大的值,尤其是在上采样过程中。如果插值后的值过大,可能会导致梯度爆炸。
    upsample_feat = F.interpolate(feat_high, scale_factor=2., mode=‘nearest’)

  3. 特征拼接:
    多个特征拼接后,如果这些特征值过大,会导致拼接后的张量值过大,进而影响后续层的梯度。

    inner_out = self.fpn_blocks[len(proj_feats) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
    
  4. 全连接层:
    全连接层的权重初始化方式可能会导致梯度爆炸。确保使用了合适的初始化方法,如 Xavier 初始化或 He 初始化。

  5. 权重共享:
    如果多个部分共享权重,需要确保这些共享权重不会导致梯度的累积效应。

3.2 解决方案

  1. 梯度剪裁:
    在反向传播过程中使用梯度剪裁,可以防止梯度爆炸。你可以在 optimizer.step() 之前加上梯度剪裁。

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  2. 使用更稳定的激活函数:
    尝试使用 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

  3. 检查权重初始化:
    确保所有层的权重初始化方式合理,避免初始值过大。

    for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
  4. 监控梯度值:
    在每次反向传播后,监控梯度的值,确保梯度不会爆炸。

    for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')
    

Enjoy~

∼ O n e p e r s o n g o f a s t e r , a g r o u p o f p e o p l e c a n g o f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim One person go faster, a group of people can go further

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1483484.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++树(二)【直径,中心】

目录: 树的直径: 树的直径的性质: 性质1:直径的端点一定是叶子节点 性质2:任意点的最长链端点一定是直径端点。 性质3:如果一棵树有多条直径,那么它们必然相交,且有极长连…

自定义注解 + Redis 实现业务的幂等性

1.实现幂等性思路 实现幂等性有两种方式: ⭐ 1. 在数据库层面进行幂等性处理(数据库添加唯一约束). 例如:新增用户幂等性处理,username 字段可以添加唯一约束. ⭐ 2. 在应用程序层面进行幂等性处理. 而在应用程序…

一款由AI编写,简洁而实用的开源IP信息查看器

大家好,今天给大家分享一款用于查询和显示用户当前 IP 地址的轻量级项目MyIP。 MyIP提供了多种功能,包括IP地址查询、网络连通性检查、WebRTC连接检测、DNS泄露检查、网速测试、MTR测试等等。 使用MyIP,我们可以轻松地查看自己的公网IP地址&…

Linux网络——套接字与UdpServer

目录 一、socket 编程接口 1.1 sockaddr 结构 1.2 socket 常见API 二、封装 InetAddr 三、网络字节序 四、封装通用 UdpServer 服务端 4.1 整体框架 4.2 类的初始化 4.2.1 socket 4.2.2 bind 4.2.3 创建流式套接字 4.2.4 填充结构体 4.3 服务器的运行 4.3.1 rec…

迁移学习在乳腺浸润性导管癌病理图像分类中的应用

1. 引言 乳腺癌主要有两种类型:原位癌:原位癌是非常早期的癌症,开始在乳管中扩散,但没有扩散到乳房组织的其他部分。这也称为导管原位癌(DCIS)。浸润性乳腺癌:浸润性乳腺癌已经扩散(侵入)到周围的乳腺组织。侵袭性癌症比原位癌更难治愈。将乳汁输送到乳…

2024717-VSCode-1.19.1-部署gcc13-C++23-win10-22h2

2024717-VSCode-1.19.1-部署gcc13-C++23-win10-22h2 一、软件环境 标签:C++ VSCode mingw gcc13分栏:C++操作系统:Windows10 x64 22h2二、操作步骤 1. 下载安装VScode 1.1官网 打开官网【https://code.visualstudio.com/Download】,选择【System Installer】【x64】,按…

Java面试八股之什么是Redis的缓存更新

什么是Redis的缓存更新 Redis的缓存更新是指当缓存中的数据发生变化时,需要将这些变化同步到缓存中以保持数据的一致性。缓存更新的目的是确保缓存中的数据始终是最新的,以便用户可以获取到最新的数据。 常见的缓存更新策略包括: 直接覆盖…

AWS基础知识

VPC (Virtual Private Cloud): 参考:https://docs.aws.amazon.com/vpc/latest/userguide/what-is-amazon-vpc.html With Amazon Virtual Private Cloud (Amazon VPC), you can launch AWS resources in a logically isolated virtual network that you’ve defined…

昇思25天学习打卡营第30天 | MindNLP ChatGLM-6B StreamChat

今天是第30天,学习了MindNLP ChatGLM-6B StreamChat。 今天是参加打卡活动的最后一天,经过这些日子的测试,昇思MindSpore效果还是不错的。 ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,具有62亿参数,基于 …

PyTorch 深度学习实践-卷积神经网络高级篇

视频指路 参考博客笔记 参考笔记二 文章目录 上课笔记10.1GoogleNet(Inception 层)代码实现10.2 Residual Net代码实现 上课笔记 可以设置padding‘same’ 使输入输出大小一致 10.1GoogleNet(Inception 层) 说明:In…

【Node.js】初识 Node.js

Node.js 概念 Node.js 是一个开源与跨平台的 JavaScript运行时环境 ,在浏览器外运行 V8 JavaScript 引擎(Google Chrome的内核),利用事件驱动、非阻塞和异步输入输出 等技术提高性能。 可以理解为 Node.js就是一个服务器端的、非阻塞式 l/O 的、事件驱…

Mac 安装MySQL 配置环境变量 修改密码

文章目录 1 下载与安装2 配置环境变量3 数据库常用命令3.1 Mac使用设置管理mysql服务启停 4 数据库修改root密码4.1 知道当前密码4.2 忘记当前密码4.3 问题 参考 1 下载与安装 官网:https://www.mysql.com/ 找到开源下载方式 下载社区版 2 配置环境变量 对于Mac…

NVIDIA Container Toolkit 安装与配置帮助文档(Ubuntu,Docker)

NVIDIA Container Toolkit 安装与配置帮助文档(Ubuntu,Docker) 本文档详细介绍了在 Ubuntu Server 22.04 上使用 Docker 安装和配置 NVIDIA Container Toolkit 的过程。 概述 NVIDIA 容器工具包使用户能够构建和运行 GPU 加速容器。即可以在容器中使用NVIDIA显卡。 架构图如…

观测云对接 Fluentd 采集业务日志最佳实践

概述 Fluentd 是一个开源数据收集器,专为简化日志管理和使日志数据更加易于访问、使用而设计。作为一个高度可扩展的工具,它能够统一数据收集和消费过程,使得构建实时分析的日志系统变得更加高效。 观测云目前已集成 Fluentd ,可…

milvus的collection操作

milvus的collection操作 创建collection import uuidfrom pymilvus import (connections,FieldSchema, CollectionSchema, DataType,Collection, )collection_name "hello_milvus" host "192.168.230.71" port 19530 username "" password…

VSCode中通过launch.json文件打断点DeBug调试代码(详细图文教程)

先吐槽 IDE编译工具调试代码是非常重要的,之前使用Pycharm很方便,直接在Configuration中配置参数就行,见下。使用VSCode进行有命令代码调试时相对麻烦一些,看其它教程没撤清楚,这里做个总结,学者耐心学习。…

01 MySQL

学习资料:B站视频-黑马程序员JavaWeb基础教程 文章目录 JavaWeb整体介绍 MySQL1、数据库相关概念2、MySQL3、SQL概述4、DDL:数据库操作5、DDL:表操作6、DML7、DQL8、约束9、数据库设计10、多表查询11、事务 JavaWeb整体介绍 JavaWeb Web:全球广域网&…

网络准入控制设备是什么?有哪些?网络准入设备臻品优选

小李:“小张,最近公司网络频繁遭遇外部攻击,我们得加强一下网络安全了。” 小张:“是啊,我听说实施网络准入控制是个不错的选择。但具体什么是网络准入控制设备?我们有哪些选择呢?” 小李微笑…

基于 MelosBoom ,捕获 DePIN 赛道发展红利

Melos是一个Web3音乐领域的先驱性生态,其允许任何人通过其工具创作音乐,生成的内容可以保存为NFT并进入流通,同时支持该音乐资产支持开放再创作。最为最具影响力以及发展潜力的Web3音乐生态,其不仅获得了来自于头部VC Binance Lab…

分布式缓存-Redis持久化

使用缓存的时候,我们经常需要对内存中的数据进行持久化(将内存中的数据写入到硬盘中)。 原因:重用数据(比如重启机器、机器故障之后恢复数据),做数据同步(比如 Redis 集群的主从节点…