深度学习——管理模型的参数

改编自李沐老师《动手深度学习》5.2. 参数管理 — 动手学深度学习 2.0.0 documentation (d2l.ai)

  在深度学习中,一旦我们选择了模型架构并设置了超参数,我们就会进入训练阶段。训练的目标是找到能够最小化损失函数的模型参数。这些参数在训练后用于预测,有时我们也需要将它们提取出来,以便在其他环境中使用,或者保存模型以便在其他软件中执行,甚至是为了科学理解而进行检查。

参数访问

访问模型参数

在PyTorch中,我们可以通过模型的层来访问参数。每一层都有自己的参数,比如权重和偏置。我们可以通过索引来访问这些参数。

import torch
from torch import nn# 定义一个简单的模型
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
output = net(X)

我们可以通过索引来检查模型中特定层的参数。

# 打印第二层(全连接层)的参数
print(net[2].state_dict())

这会显示第二层的权重和偏置,它们是模型学习的关键部分。

访问特定参数的值

我们可以进一步提取特定参数的值。这通常在我们需要对参数进行特定操作时非常有用。

# 打印第二层的偏置参数
print(net[2].bias)
print(net[2].bias.data)

参数是复合对象,包含值、梯度和其他信息。在没有进行反向传播的情况下,参数的梯度处于初始状态。

一次性访问所有参数

当需要对所有参数执行操作时,可以一次性访问所有参数。这在处理大型模型时尤其有用。

# 打印所有层的参数名称和形状
print(*[(name, param.shape) for name, param in net.named_parameters()])

从嵌套块收集参数

当模型由多个子模块组成时,我们可以通过类似列表索引的方式来访问这些子模块的参数。

# 定义一个子模块
def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())# 定义一个包含多个子模块的模型
def block2():net = nn.Sequential()for i in range(4):net.add_module(f'block {i}', block1())return net# 创建一个包含嵌套子模块的模型
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
output = rgnet(X)# 打印模型结构
print(rgnet)# 访问嵌套子模块的参数
print(rgnet[0][1][0].bias.data)

参数初始化

内置初始化

PyTorch提供了多种预置的初始化方法,我们可以根据需要选择。

# 初始化所有权重为高斯随机变量,偏置为0
def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)
net.apply(init_normal)

自定义初始化

有时,我们需要自定义初始化方法来满足特定的需求。

# 自定义初始化方法
def my_init(m):if type(m) == nn.Linear:print("Init", *[(name, param.shape)for name, param in m.named_parameters()][0])nn.init.uniform_(m.weight, -10, 10)m.weight.data *= m.weight.data.abs() >= 5net.apply(my_init)

参数绑定

有时我们希望在多个层间共享参数。在PyTorch中,我们可以通过引用同一个层的参数来实现这一点。

# 定义一个共享层
shared = nn.Linear(8, 8)# 使用共享层构建模型
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared, nn.ReLU(),shared, nn.ReLU(),nn.Linear(8, 1))
output = net(X)# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])# 改变一个参数,另一个也会改变
net[2].weight.data[0, 0] = 100
print(net[2].weight.data[0] == net[4].weight.data[0])

这个例子展示了如何在模型的不同层之间共享参数,以及如何通过改变一个参数来影响另一个参数。这种技术在构建复杂的神经网络时非常有用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1540419.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计 美发管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

Chat2VIS: Generating Data Visualizations via Natural Language

Chat2VIS:通过使用ChatGPT, Codex和GPT-3大型语言模型的自然语言生成数据可视化 梅西大学数学与计算科学学院,新西兰奥克兰 IEEE Access 1 Abstract 数据可视化领域一直致力于设计直接从自然语言文本生成可视化的解决方案。自然语言接口 (NLI) 的研究为这些技术的…

从虚拟到现实:数字孪生与数字样机的进化之路

数字化技术高速发展的当下,计算机辅助技术已成为产品设计研发中不可或缺的一环,数字样机(Digital Prototype, DP)与数字孪生技术便是产品研发数字化的典型方法。本文将主要介绍数字样机与数字孪生在国内外的发展,并针对…

Java 并发编程 —— AQS 抽象队列同步器

文章目录 什么是 AQS底层数据结构—— CLH 队列入队和出队状态标志位AQS 的代码设计思路AQS 提供的钩子方法参考资料 什么是 AQS AQS 是 JUC 提供的一个用于构建锁和同步容器的基础类,用于减少由于无效争夺导致的资源浪费和性能恶化。JUC 包内的许多类都是基于 AQS…

【JPCS出版】第四届电气工程与计算机技术国际学术会议(ICEECT 2024,9月27-29)

会议信息 会议官网:www.iceect.com 2024 4th International Conference on Electrical Engineering and Computer Technologywww.iceect.com 时间地点:2024年9月27日-29日 | 线上(ZOOM) 最终截稿时间:9月23日 主办…

【C++篇】C++类与对象深度解析(六):全面剖析拷贝省略、RVO、NRVO优化策略

文章目录 C类与对象前言读者须知RVO 与 NRVO 的启用条件如何确认优化是否启用? 1. 按值传递与拷贝省略1.1 按值传递的概念1.2 示例代码1.3 按值传递的性能影响1.3.1 完全不优化 1.4 不同编译器下的优化表现1.4.1 Visual Studio 2019普通优化1.4.2 Visual Studio 202…

2024.9.20营养小题【1】

这道题并不难,但是通过这道题,对知识有了一些更深一点的理解吧。 我们知道,数组名代表的其实是数组中首元素的指针;字符串其实是一个数组;所以字符串名是指向字符串中首元素地址的指针;strlen(字符串名&am…

Spring Boot利用dag加速Spring beans初始化

1.什么是Dag? 有向无环图(Directed Acyclic Graph),简称DAG,是一种有向图,其中没有从节点出发经过若干条边后再回到该节点的路径。换句话说,DAG中不存在环路。这种数据结构常用于表示并解决具有依赖关系的问题。 DAG的…

什么是损失函数?常见的损失函数有哪些?

损失函数 什么是损失函数?损失函数作用如何设计损失函数常见的损失函数有哪些? 什么是损失函数? 损失函数(Loss Function),也称为误差函数,是机器学习和深度学习中的一个重要概念。它用于衡量模…

python怎么打开编辑器

1、在电脑开始菜单中点击所有程序,找到Python程序,点击其中idle。 2、然后点击左上角的“File”,打开菜单,在下拉菜单中选择“New File”选项,就可打开python编辑器了。 3、在打开的python编辑器中就可以输入自己想写的…

105.游戏安全项目-基址的技术原理-分析技巧

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于:易道云信息技术研究院 本人写的内容纯属胡编乱造,全都是合成造假,仅仅只是为了娱乐,请不要盲目相信…

如何衡量企业品牌力?判断指标有哪些?

企业品牌力是指品牌在市场中的竞争力和影响力,它反映了品牌的价值、知名度、忠诚度、感知质量、差异化以及市场表现等方面。要去衡量一个企业的品牌力,大多从品牌的知名度、忠诚度、所占市场份额、顾客口碑、社媒影响力、品牌资产价值等多方面去判断。我…

sqoop的安装与简单使用

文章目录 一、安装1、上传,解压,重命名2、修改环境变量3、修改配置文件4、上传驱动包5、拷贝jar包 二、import命令1、将mysql的数据导入到hdfs上2、将mysql的数据导入到hive上3、增量导入数据 三、export命令1、从hdfs导出到mysql中2、从hive导出到mysql…

企业微信oauth2提示应用无法使用

问题描述: 生成oauth2之后,我a公司是服务商,我给b公司的人去点授权链接会提示这个 应用服务商还没有在企业微信为你开通接口调用许可」,导致无法使用此应用,请联系服务商开通 正文 你先要知道一件事!&…

Jenkins私有化部署

最终目标 与GitLab配合,实践前端自动化,详细内容移步基于Jenkins和GitLab的前端自动化实践 前置条件 一台云服务器云服务器上已安装Docker了解Docker基础 使用Docker安装Jenkins 参考github文档安装 docker run --name docker_jenkins --privilege…

操作系统 --- 进程的同步和互斥问题以及进程互斥实现方法(软件、硬件实现)、同步机制遵循的四条准则

目录 一、进程同步 二、进程互斥 三、进程互斥的实现方法 3.1 软件实现 3.1.1 单标志法(存在的主要问题:违背“空闲让进”原则) 3.1.1.1 基本思想 3.1.1.2 单标志法的基本概念及执行流程 3.1.1.3 特点 3.1.2 双标志先检查法&#…

进程间的通信 2 消息队列

system V IPC IPC : Inter-Process Communication (进程间通讯) System V IPC 对象共有三种: 消息队列共享内存信号量 System V IPC 是由内核维护的若干个对象,通过ipcs命名查询 每个 IPC 对象都有一个唯一的 ID,可以通过ftok()函数生成 …

使用SoapUI、Postman工具调用Webservice方法

SoapUI工具更适合调用Webservice使用。 1.使用SoapUI工具调用Webservice 创建“New SOAP Project” 自行定义一个项目名称,输入wsdl地址: 在左侧列表找到方法名,双击“Request 1”, 在请求数据中,添加对应的参数,然…

Linux--禁止root用户通过ssh直接登录

原文网址:Linux--禁止root用户通过ssh直接登录_IT利刃出鞘的博客-CSDN博客 简介 本文介绍Linux服务器怎样禁止root用户通过ssh直接登录。 为什么要禁止? 因为root用户是每个Linux系统都有的,黑客可以使用root用户名尝试不同的密码来暴力破…

【笔记】自动驾驶预测与决策规划_Part3_路径与轨迹规划

文章目录 0. 前言1. 基于搜索的路径规划1.1 A* 算法1.2 Hybrid A* 算法 2. 基于采样的路径规划2.1 Frent Frame方法2.2 Cartesian →Frent 1D ( x , y ) (x, y) (x,y) —> ( s , l ) (s, l) (s,l)2.3 Cartesian →Frent 3D2.4 贝尔曼Bellman最优性原理2.5 高速轨迹采样——…