【PyTorch 攻略 (4/7)】张量和梯度函数

一、说明

 

        W在训练神经网络时,最常用的算法是反向传播。在该算法中,参数(模型权重)根据损失函数相对于给定参数的梯度进行调整。损失函数计算神经网络产生的预期输出和实际输出之间的差异。
        目标是获得尽可能接近零的损失函数的结果。反向传播算法通过神经网络向后遍历,以调整权重和偏差以重新训练模型。这种随着时间的推移重新训练模型的来回和前进过程将损失减少到 0,称为梯度下降

        为了计算这些梯度,PyTorch有一个名为torch.autograd的内置微分引擎。它支持任何计算图的梯度自动计算。

%matplotlib inline
import torchx = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

二、张量、函数和计算图

        这段代码定义了以下计算图:
        在这个网络中,wb是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

        我们应用于张量来构造计算图的函数实际上是一个对象类函数。此对象知道如何在向前方向上计算函数,以及如何向后传播步骤中计算其导数。对向后传播函数的引用存储在张量的 grad_fn 属性中。

print('Gradient function for z =',z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)
Gradient function for z = <AddBackward0 object at 0x00000280CC630CA0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x00000280CC630310>

三、计算梯度

        为了优化神经网络中参数的权重,我们需要计算损失函数相对于参数的导数,即我们需要在xy的固定值下∂w/∂loss和∂loss/∂b。为了计算这些导数,我们调用 loss.backward(),然后从 w.grad 和 b.grad 中检索值。

loss.backward()
print(w.grad)
print(b.grad)
tensor([[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279]])
tensor([0.2739, 0.0490, 0.3279]) 

四、禁用渐变跟踪

        默认情况下,所有 requires_grad=True 的张量都在跟踪它们的计算历史并支持梯度计算。但是,在某些情况下,我们不需要这样做,例如,当我们训练了模型并只想将其应用于某些输入数据时,即我们只想通过网络进行前向计算。我们可以通过在计算代码周围用 torch.no_grad() 块来停止跟踪计算。

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)
True
False

        您可能想要禁用梯度跟踪的原因如下:
- 将神经网络中的某些参数标记为冻结参数。这是微调预训练网络的一种非常常见的方案。
- 在只执行正向传递时加快计算速度,因为对不跟踪梯度的张量的计算会更有效。

        C从概念上讲,Autograd 在由函数对象组成的有向无环图 (DAG) 中记录数据(张量)和所有执行的操作(以及生成的新张量)。在此 DAG 中,叶子是输入张量,根是输出张量。通过跟踪从根到叶的图形,您可以使用链式规则自动计算梯度。

        在正向传递中,autograd 同时做两件事:
- 运行请求的操作以计算生成的张量
- 在 DAG 中维护操作的梯度函数

        在向后传递中,.back() 在 DAG 根目录上调用。然后,
autograd :- 计算每个.grad_fn
的梯度 - 将它们累积在相应张量的 .grad 属性
中 - 使用链式规则一直传播到叶张量

DAG 在 PyTorch 中是动态的。

        需要注意的重要一点是,图形是从头开始重新创建的;每次 .backward() 调用后,Autograd 开始填充一个新图形。这正是允许您在模型中使用控制流语句的原因。
如果需要,您可以在每次迭代时更改形状、大小和操作。

五、张量梯度和雅可比积

        在许多情况下,我们有一个标量损失函数,我们需要计算相对于某些参数的梯度。但是,在某些情况下,输出函数是任意张量。在这种情况下,PyTorch 允许您计算所谓的雅可比乘积,而不是实际的梯度。

        对于向量函数 y* = f(x*),其中 x* = (x1, ..., xn) 和 y* = (y1, ..., ym),
y* 相对于 x* 的梯度由雅可比矩阵给出其元素 J 包含 ∂xi/∂yj

        PyTorch 不是计算雅可比矩阵本身,而是允许您计算雅可比乘积。J 对于给定的输入向量 v = (v1, ..., vm)。
这是通过使用 v 作为参数向调用来实现的。v 的大小应该与原始张量的大小相同,我们想要计算乘积。

inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.ones_like(inp), retain_graph=True)
print("First call\n", inp.grad)out.backward(torch.ones_like(inp), retain_graph=True)
print("\nSecond call\n", inp.grad)inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print("\nCall after zeroing gradients\n", inp.grad)
First calltensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.],[2., 2., 2., 2., 4.]])Second calltensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.],[4., 4., 4., 4., 8.]])Call after zeroing gradientstensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.],[2., 2., 2., 2., 4.]])

        请注意,当我们使用相同的参数第二次向后调用时,梯度的值是不同的。发生这种情况是因为在进行向后传播时,PyTorch 会累积梯度,即计算梯度的值被添加到计算图的所有叶节点的 grad 属性中。如果要计算正确的梯度,则需要在之前将 grad 属性归零。在现实生活中的训练中,优化器可以帮助我们做到这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/139061.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

QUIC协议报文解析(三)

在前面的两篇文字里我们简单介绍了QUIC的发展历史&#xff0c;优点以及QUIC协议的连接原理。本篇文章将会以具体的QUIC报文为例&#xff0c;详细介绍QUIC报文的结构以及各个字段的含义。 早期QUIC版本众多&#xff0c;主要有谷歌家的gQUIC&#xff0c;以及IETF致力于将QUIC标准…

打印由数字组成的金字塔图案——python

1222 33333 4444444 555555555打印由数字组成的金字塔图案。但n9时&#xff0c;如下图所示。 输入格式: 输入一个整数n&#xff08;1<A<9&#xff09;。 输出格式: 输出由数字组成的金字塔图案。 输入样例: 在这里给出一组输入。例如&#xff1a; 5输出样例: 在这…

【 2023华为杯C题】大规模创新类竞赛评审方案研究(思路、代码......)

目录 1 题目概述 2 问题 3 极差的定义及标准分的计算方法 4 题目及数据下载 5 思路、代码下载...... 1 题目概述 现在创新类竞赛很多&#xff0c;其中规模较大的竞赛&#xff0c;一般采用两阶段&#xff08;网评、现场评审&#xff09;或三阶段&#xff08;网评、现场评审…

高效畅通的iOS平台S5配置指南

在iOS平台上&#xff0c;使用S5代理ip访问互联网是一种非常有用的技巧。无论是为了保证隐私安全&#xff0c;还是解决网络限制问题&#xff0c;S5代理ip都能为您提供更快、更稳定的互联网访问体验。本文将为您详细介绍如何在iOS平台上配置和使用S5代理ip&#xff0c;让您的网络…

git之撤销工作区的修改和版本回溯

有时候在工作区做了一些修改和代码调试不想要了,可如下做 (1)步骤1:删除目录代码,确保.git目录不能修改 (2)git log 得到相关的commit sha值 可配合git reflog 得到相要的sha值 (3)执行git reset --hard sha值,可以得到时间轴任意版本的代码 git reset --hard sha值干净的代…

【Java 基础篇】Java网络编程实战:P2P文件共享详解

Java网络编程是现代软件开发中不可或缺的一部分&#xff0c;因为它允许不同计算机之间的数据传输和通信。在本篇博客中&#xff0c;我们将深入探讨Java中的P2P文件共享&#xff0c;包括什么是P2P文件共享、如何实现它以及一些相关的重要概念。 什么是P2P文件共享&#xff1f; …

23个销量最高的3D扫描仪【2023】

如果你可以 3D 扫描它&#xff0c;你就可以 3D 打印它。 市场上 3D 扫描仪的种类和质量非常丰富&#xff0c;机器尺寸、功能和价格各异。 这样的选择虽然本身是一件很棒的事情&#xff0c;但也会让从无用的东西中挑选出宝石成为一件苦差事。 推荐&#xff1a;用 NSDT编辑器 快速…

HTTP各版本差异

HTTP1.0 无法复用连接 HTTP1.0为每个请求单独新开一个TCP连接 #mermaid-svg-9N3exXRS4VvT4bWF {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-9N3exXRS4VvT4bWF .error-icon{fill:#552222;}#mermaid-svg-9N3exXRS…

Reinforcement Learning(二)--on-policy和off-policy

1.前言 强化学习&#xff08;Reinforcement learning&#xff0c;RL&#xff09;是机器学习的一个分析&#xff0c;特点是概念多、公式多、入门门槛高&#x1f972;&#xff08;别问我怎么知道的&#xff09;。本篇文章着重讲解RL最重要的概念之一&#xff0c;即on-policy和of…

2023工博会强势回归!智微工业携八大系列重磅亮相

中国国际工业博览会&#xff08;简称"中国工博会"&#xff09;自1999年创办以来&#xff0c;历经二十余年发展创新&#xff0c;通过专业化、市场化、国际化、品牌化运作&#xff0c;已发展成为通过国际展览业协会&#xff08;UFI&#xff09;认证、中国工业领域规模最…

mybatis/mp批量插入非自增主键数据

文章目录 前言一、mp的批量插入是假的二、真正的批量插入1.利用sql注入器处理2.采用自编码,编写xml批量执行生成内容如下: 三 问题问题描述问题原因问题解决粘贴一份,兼容集合替换原有文件 总结自增与非自增区别: 前言 mybatis/mp 在实际开发中是常用的优秀持久层框架,但是在非…

Linux:GlusterFS 集群

GlusterFS介绍 1&#xff09;Glusterfs是一个开源的分布式文件系统,是Scale存储的核心,能够处理千数量级的客户端.在传统的解决 方案中Glusterfs能够灵活的结合物理的,虚拟的和云资源去体现高可用和企业级的性能存储. 2&#xff09;Glusterfs通过TCP/IP或InfiniBand RDMA网络链…

【C++】String类基本接口介绍及模拟实现(多看英文文档)

string目录 如果你很赶时间&#xff0c;那么就直接看我本标题下的内容即可&#xff01;&#xff01; 一、STL简介 1.1什么是STL 1.2STL版本 1.3STL六大组件 1.4STL重要性 1.5如何学习STL 二、什么是string&#xff1f;&#xff1f;&#xff08;本质上是一个类&#xff0…

【Redis】深入探索 Redis 的数据类型 —— 列表 List

文章目录 一、List 类型介绍二、List 类型相关命令2.1 LPUSH 和 RPUSH、LPUSHX 和 RPUSHX2.2 LPOP 和 RPOP、BLPOP 和 BRPOP2.3 LRANGE、LINDEX、LINSERT、LLEN2.4 列表相关命令总结 三、List 类型内部编码3.1 压缩列表&#xff08;ziplist&#xff09;3.2 链表&#xff08;lin…

Git错误解决:如何处理“could not determine hash algorithm“问题

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

openssl创建CA证书教程

配置生成CA证书 总示意图&#xff1a; (1)&#xff0c;通过openssl创建CA证书 第一步&#xff1a;创建一个秘钥&#xff0c;这个便是CA证书的根本&#xff0c;之后所有的东西都来自这个秘钥 # 通过rsa算法生成2048位长度的秘钥 openssl genrsa -out myCA.key 2048 第二步&#…

Android Camera2获取摄像头的视场角(FOV)信息

一、概念 FOV&#xff08;Field of View&#xff09;是一个用于描述视野范围的术语。它通常用于计算设备&#xff08;如摄像机、虚拟现实头显或眼睛&#xff09;所能捕捉到的可见区域。 水平FOV&#xff08;Horizontal FOV&#xff09;&#xff1a;描述视野在水平方向上的范围…

JVM面试题-JVM对象的创建过程、内存分配、内存布局、访问定位等问题详解

对象 内存分配的两种方式 指针碰撞 适用场合&#xff1a;堆内存规整&#xff08;即没有内存碎片&#xff09;的情况下。 原理&#xff1a;用过的内存全部整合到一边&#xff0c;没有用过的内存放在另一边&#xff0c;中间有一个分界指针&#xff0c;只需要向着没用过的内存…

【最新面试问题记录持续更新,java,kotlin,android,flutter】

最近找工作&#xff0c;复习了下java相关的知识。发现已经对很多概念模糊了。记录一下。部分是往年面试题重新整理&#xff0c;部分是自己面试遇到的问题。持续更新中~ 目录 java相关1. 面向对象设计原则2. 面向对象的特征是什么3. 重载和重写4. 基本数据类型5. 装箱和拆箱6. …

【数据结构】顺序表与ArrayList

作者主页&#xff1a;paper jie 的博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文录入于《JAVA数据结构》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心打造的。笔者用重金(时间和精…