对pytorch optimizer中state_dict、state、param_groups的简要理解

先说结论:

  • state_dict():一个dict,里面有两个key(stateparam_groups),

    • state这个key对应的value是各个权重对应的优化器状态。具体来说,一个model有很多权重,model.parameters()会打印出该模型的各层的权重,比如使用Adam,每层权重都有一个momentum和variance,形状与权重相同,还有该层当前更新到的步数。state_dict()['state']是一个dict,每个key-value item结构如下:
      该权重在model.parameters()中的位置 : {'step': tensor, 'exp_avg': tensor, # exp_avg: exponential moving average of gradient values'exp_avg_sq: tensor # exp_avg_sq: exponential moving average of squared gradient values
      
    • param_groups这个key对应的value是一个list,其中每个元素都是超参数组成的一个dict,因为不同的权重可以使用不同的超参数,所以需要使用list来表示,而且dict中params表示该超参数配置作用于哪些权重。state_dict()['param_groups']是一个list,每个元素结构如下
      {'lr': 0.01, 'weight_decay': 0,  ...  , 'params', [该超参数配置作用于的权重的位置]}
      
  • state:是一个defaultdict,包含的信息类似于state_dict()['state']+model.parameters(),具体来说,每个key-value item结构如下:

    param_tensor :{'step': tensor, 'exp_avg': tensor, 'exp_avg_sq': tensor,	
    }
    
  • param_groups:是一个list,包含的信息类似于state_dict()['param_groups']+model.parameters(),具体来说,每个元素结构如下:

    {'params': [param1, param2, ...]'lr': 0.01, 'weight_decay': 0, ...# 注意相较于state_dict()['param_groups'],原来'params'这个key对应的是param的索引位置,现在直接就是tensor了
    }
    

示例代码:

import torch
from torch.nn import Module
from torch.optim import Adamclass MyModel(Module):def __init__(self, in_dim, hidden_dim):super(MyModel, self).__init__()self.linear = torch.nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)self.linear2 = torch.nn.Linear(in_features=hidden_dim, out_features=in_dim, bias=False)def forward(self, x):y = self.linear(x)out = self.linear2(y)return outin_dim = 5
hidden_dim = 2
model = MyModel(in_dim=in_dim, hidden_dim=hidden_dim)optimier = Adam([{'params': model.linear.parameters(), 'lr': 0.05},{'params': model.linear2.parameters()}
], lr=0.01)x = torch.randn((in_dim))
out = model(x)
loss = torch.sum(out, dim=-1)
optimier.zero_grad()
loss.backward()
optimier.step()print('#' * 100)
print(optimier.state_dict())print('#' * 100)
print(optimier.state)print('#' * 100)
print(optimier.param_groups)

输出:

####################################################################################################
# state_dict()
{'state': {0: {'step': tensor(1.), 'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])}, 1: {'step': tensor(1.), 'exp_avg': tensor([0.0406, 0.0112]), 'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])}, 2: {'step': tensor(1.), 'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])}}, 'param_groups': [{'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 'params': [0, 1]}, {'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 'params': [2]}]}####################################################################################################
# state
defaultdict(<class 'dict'>, {Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])}, Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([0.0406, 0.0112]), 'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])}, Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])}}
)
####################################################################################################
# param_groups
[{'params': [Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True), Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True)], 'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False}, {'params': [Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False}
]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1488650.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

深入理解 Java 虚拟机第三版(周志明)

这次社招选的这本作为 JVM 资料查阅&#xff0c;记录一些重点 1. 虚拟机历史 Sun Classic VM &#xff1a;已退休 HotSpot VM&#xff1a;主流虚拟机&#xff0c;热点代码探测技术 Mobile / Embedded VM &#xff1a;移动端、嵌入式使用的虚拟机 2.2 运行时数据区域 程序计…

Vue入门记录(一)

效果 本文为实现如下前端效果的学习实践记录&#xff1a; 实践 入门的最佳实践我觉得是先去看官网&#xff0c;官网一般都会有快速入门指引。 根据官网的快速上手文档&#xff0c;构建一个新的Vue3TypeScript&#xff0c;查看新建的项目结构&#xff1a; 现在先重点关注comp…

学生信息管理系统详细设计文档

一、设计概述 学生信息管理系统是一个用于管理学生信息的软件系统&#xff0c;旨在提高学校对学生信息的管理效率。本系统主要包括学生信息管理、课程信息管理、成绩信息管理、班级信息管理等功能模块。详细设计阶段的目标是确定各个模块的实现算法&#xff0c;并精确地表达这…

【AIGC】Llama-3 官方技术报告

Llama-3 技术报告&#xff08;中文翻译&#xff09; 欢迎关注【youcans的AGI学习笔记】原创作品 0. 简介 现代人工智能&#xff08;AI&#xff09;系统的核心驱动力来自于基础模型。 本文介绍了一组新的基础模型&#xff0c;称为 Llama 3。它是一个语言模型系列&#xff0c;原…

基于STM32瑞士军刀--【FreeRTOS开发】学习笔记(二)|| 堆 / 栈

堆和栈 1. 堆 堆就是空闲的一块内存&#xff0c;可以通过malloc申请一小块内存&#xff0c;用完之后使用再free释放回去。管理堆需要用到链表操作。 比如需要分配100字节&#xff0c;实际所占108字节&#xff0c;因为为了方便后期的free&#xff0c;这一小块需要有个头部记录…

Mysql第五次作业 触发器和存储过程

1、建库建表 建立触发器&#xff0c;订单表中增加订单数量后&#xff0c;商品表商品数量同步减少对应的商品订单出数量,并测试 建立触发器&#xff0c;实现功能:客户取消订单&#xff0c;恢复商品表对应商品的数量 建立触发器&#xff0c;实现功能:客户修改订单&#xff0c;商品…

护眼大路灯哪个口碑最好?五款专业护眼大路灯分享

护眼大路灯哪个口碑最好&#xff1f;护眼大路灯作为一款能够真正改善光线环境&#xff0c;有效做到减少视觉疲劳的护眼大路灯&#xff0c;逐渐成为众多家庭的必备照明神器。然而&#xff0c;市面上的护眼大路灯品牌琳琅满目&#xff0c;性能参差不齐&#xff0c;部分低质产品在…

docker安装httpd服务

docker安装httpd 一、简介 1、docker Docker是一个开源的容器化平台&#xff0c;可以轻松构建、发布和运行应用程序 2、httpd Apache HTTP服务器&#xff08;httpd&#xff09;是一个流行的开源Web服务器软件&#xff0c;用于托管网站和Web应用 二、准备环境 1、CentOS …

Docker安全管理与HTTPS协议

1 Docker容器的安全管理注意事项 Docker本身的架构与机制就可能产生问题&#xff0c;例如这样一种攻击场景&#xff0c;黑客已经控制了宿主机上的一些容器&#xff0c;或者获得了通过在公有云上建立容器的方式&#xff0c;然后对宿主机或其他容器发起攻击。 1. 容器之间的局…

C++ Lambda表达式个人理解

1、Lambda概述 lambda表达式&#xff08;也称为lambda函数&#xff09;是在调用或作为函数参数传递的位置处定义匿名函数对象的便捷方法。通常&#xff0c;lambda用于封装传递给算法或异步方法的几行代码。 2、Lambda表达式定义 2.1 Lambda表达式实例 Lambda有很多叫法&…

按图搜索新体验:阿里巴巴拍立淘API返回值详解

阿里巴巴拍立淘API是一项基于图片搜索的商品搜索服务&#xff0c;它允许用户通过上传商品图片&#xff0c;系统自动识别图片中的商品信息&#xff0c;并返回与之相关的搜索结果。以下是对阿里巴巴拍立淘API返回值的详细解析&#xff1a; 一、主要返回值内容 商品信息 商品列表…

深度学习趋同性的量化探索:以多模态学习与联合嵌入为例

深度学习趋同性的量化探索&#xff1a;以多模态学习与联合嵌入为例 参考文献 据说是2024年最好的人工智能论文&#xff0c;是否有划时代的意义&#xff1f; [2405.07987] The Platonic Representation Hypothesis (arxiv.org) ​arxiv.org/abs/2405.07987 趋同性的量化表达 …

CentOS搭建Apache服务器

安装对应的软件包 [roothds ~]# yum install httpd mod_ssl -y 查看防火墙的状态和selinux [roothds ~]# systemctl status firewalld [roothds ~]# cat /etc/selinux/config 若未关闭&#xff0c;则关闭防火墙和selinux [roothds ~]# systemctl stop firewalld [roothds ~]# …

全新微软语音合成网页版源码,短视频影视解说配音网页版系统-仿真人语音

源码介绍 最新微软语音合成网页版源码&#xff0c;可以用来给影视解说和短视频配音。它是TTS文本转语言&#xff0c;API接口和PHP源码。 这个微软语音合成接口的源码&#xff0c;超级简单&#xff0c;就几个文件搞定。用的是官方的API&#xff0c;试过了&#xff0c;合成速度…

Github个人网站搭建详细教程【Github+Jekyll模板】

文章目录 前言一、介绍1 Github Pages是什么2 静态网站生成工具3 Jekyll简介Jekyll 和 GitHub 的关系 4 Mac系统Jekyll的安装及使用安装Jekyll的简单使用 二、快速搭建第一个Github Pages网站三、静态网站模板——Chirpy1 个人定制 四、WordPress迁移到Github参考资料 前言 23…

DMv8共享存储集群部署

DMv8共享存储集群部署 环境说明 操作系统&#xff1a;centos7.6 服务器&#xff1a;2台虚拟机 达梦数据库版本&#xff1a;达梦V8 安装前准备工作 参考达梦官方文档&#xff1a;https://eco.dameng.com/document/dm/zh-cn/ops/DSC-installation-cluster.html#%E4%B8%80%E3…

Java面试八股之什么是spring boot starter

什么是spring boot starter Spring Boot Starter是Spring Boot项目中的一个重要概念。它是一种依赖管理机制&#xff0c;用于简化Maven或Gradle配置文件中的依赖项声明。Spring Boot Starter提供了一组预定义的依赖关系&#xff0c;这些依赖关系被封装在一个单一的包中&#x…

昇思25天学习打卡营第22天|munger85

LSTMCRF序列标注 我们希望得到这个模型来对词进行标注&#xff0c;B是开始&#xff0c;I是实体词的非开始&#xff0c;O是非实体词。 我们首先需要lstm对序列里token的记忆&#xff0c;和计算每个token发到crf的分数&#xff0c;发完了再退出来&#xff0c;最后形成1模型。那么…

海山数据库(He3DB)技术解析:海山Redis定时任务与持久化管控设计

文章目录 引言一、背景介绍二、具体实现1、多副本容灾功能2、主备切换后任务断点续做功能3、持久化管控编排功能 三、总结作者 引言 云Redis数据库服务是目前广泛应用的模式&#xff0c;其数据持久化方案是现在研究的热点内容&#xff0c;数据持久化操作主要由参数设置自动触发…

500元左右有好用的开放式耳机吗?百元开放式耳机推荐

正所谓授人以鱼不如授人以渔&#xff0c;在此大圣分享一下我选开放式耳机的的一切技巧。 在挑选开放式耳机的时候&#xff0c;我主要会考察以下这些点&#xff1a; 1-音质表现 关注频响范围&#xff0c;确保能涵盖您常听音乐类型所需的频率。 留意声音的清晰度、层次感和失…