PyTorch核心概念:从梯度、计算图到连续性的全面解析(二)

文章目录

  • pytorch中的Autograd
    • 计算图
    • 叶子张量
  • inplace操作
  • PyTorch的两大特点
    • 动态图
    • eager execution
  • PyTorch中的Variable
  • 参考文献

pytorch中的Autograd

pytorch提供了自动求导机制和对GPU的支持
了解自动求导背后的原理和规则:当使用pytorch中没有的loss function时,需要我们自己写loss function

计算图

假设我们有一个复杂的神经网络模型,我们把它想象成一个错综复杂的管道结构,不同的管道之间通过节点连接起来,我们有一个注水口,一个出水口。我们在入口注入数据之后,数据就沿着设定好的管道路线缓缓流动到出水口,这时候我们就完成了一次正向传播。想象一下输入的 tensor 数据在管道中缓缓流动的场景,这就是为什么 TensorFlow 叫 TensorFlow 的原因
计算图中的两个元素:tensor和Function

  • Function:在计算中某个节点所进行的计算,比如加、减、乘、除、卷积

Function 内部forward()backward()两个方法

a = torch.tensor(2.0, requires_grad=True)
b = a.exp()
print(b)
# tensor(7.3891, grad_fn=<ExpBackward>)

在我们做正向传播的过程中,除了执行forward()操作之外,还要为反向计算图添加 Function 节点。在上边这个例子中,变量 b 在反向传播中所需要进行的操作是 <ExpBackward>
假如我们需要计算这么一个模型:

l1 = input x w1
l2 = l1 + w2
l3 = l1 x w3
l4 = l2 x l3
loss = mean(l4)

在这里插入图片描述

正向传播计算图
在整张计算图中,只有 input 一个变量是 requires_grad=False 的。正向传播过程的具体代码如下:
x = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)l1 =x * w1
l2 = l1 + w2
l3 = l1 * w3
l4 = l2 * l3
loss = l4.mean()print(w1.data, w1.grad, w1.grad_fn)
# tensor(2.) None Noneprint(l1.data, l1.grad, l1.grad_fn)
# tensor([[2., 2.],
#         [2., 2.]]) None <MulBackward0 object at 0x000001EBE79E6AC8>print(loss.data, loss.grad, loss.grad_fn)
# tensor(40.) None <MeanBackward0 object at 0x000001EBE79D8208>

在正向传播的过程中,变量 l1grad_fn 储存着乘法操作符 <MulBackward0>,用于在反向传播中指导梯度的计算;w1 是用户自己定义的,不是通过计算得来的,所以其 grad_fn 为空,同时因为还没有进行反向传播,grad 的值也为空
在这里插入图片描述

反向传播计算图
反向图也比较简单,从 loss 这个变量开始,通过链式法则,依次计算出各部分的梯度
x = [1.0, 1.0, 1.0, 1.0]
w1 = [2.0, 2.0, 2.0, 2.0]
w2 = [3.0, 3.0, 3.0, 3.0]
w3 = [4.0, 4.0, 4.0, 4.0]l1 = x * w1 = [2.0, 2.0, 2.0, 2.0]
l2 = l1 + w2 = [5.0, 5.0, 5.0, 5.0]
l3 = l1 * w3 = [8.0, 8.0, 8.0, 8.0] 
l4 = l2 * l3 = [40.0, 40.0, 40.0, 40.0] 
loss = mean(l4) = 40.0loss.backward()print(w1.grad, w2.grad, w3.grad)
# 梯度之和
# tensor(28.) tensor(8.) tensor(10.)print(l1.grad, l2.grad, l3.grad, l4.grad, loss.grad)
# None None None None None

由于l1l2l3l4均未设置requires_grad=True,所以PyTorch不会自动追踪其梯度

叶子张量

对于任意一个张量来说,我们可以用 tensor.is_leaf 来判断它是否是叶子张量(leaf tensor)。在反向传播过程中,只有 is_leaf=True 的时候,需要求导的张量的梯度结果才会被最后保留下来
requires_grad=True时,如何判断是否是叶子张量:当这个 tensor 是用户创建的时候,它是一个叶子节点,当这个 tensor 是由其他运算操作产生的时候,它就不是一个叶子节点

a = torch.ones([2, 2], requires_grad=True)
print(a.is_leaf)
# Trueb = a + 2
print(b.is_leaf)
# False
# 因为 b 不是用户创建的,是通过计算生成的

提出叶子张量概念的目的是节省内存或显存
那些非叶子结点,是通过用户所定义的叶子节点的一系列运算生成的,也就是这些非叶子节点都是中间变量,一般情况下,用户不会去使用这些中间变量的梯度,所以为了节省内存,它们在用完之后就会被释放
在上述反向传播计算图中,标绿的是叶子张量
对于叶子节点来说,它们的 grad_fn 属性都为空;而对于非叶子结点来说,因为它们是通过一些操作生成的,所以它们的 grad_fn 不为空

inplace操作

inplace 指的是在不更改变量的内存地址的情况下,直接修改变量的值
我们来看两种情况,大家觉得这两种情况哪个是 inplace 操作,哪个不是?或者两个都是 inplace?

# 情景 1
a = a.exp()# 情景 2
a[0] = 10

答案是:情景1不是 inplace,类似 Python 中的 i=i+1, 而情景2是 inplace 操作,类似 i+=1

接下来以 PyTorch 不同的报错信息作为驱动介绍inplace操作
第一个报错信息:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# 我们要用到 id() 这个函数,其返回值是对象的内存地址
# 情景 1
a = torch.tensor([3.0, 1.0])
print(id(a)) # 2112716404344a = a.exp()
print(id(a)) # 2112715008904
# 在这个过程中 a.exp() 生成了一个新的对象,然后再让 a
# 指向它的地址,所以这不是个 inplace 操作# 情景 2
a = torch.tensor([3.0, 1.0])
print(id(a)) # 2112716403840a[0] = 10
print(id(a), a) # 2112716403840 tensor([10.,  1.])
# inplace 操作,内存地址没变

PyTorch通过tensor._version检测tensor是否发生inplace操作

a = torch.tensor([1.0, 3.0], requires_grad=True)
b = a + 2
print(b._version) # 0loss = (b * b).mean()
b[0] = 1000.0
print(b._version) # 1loss.backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation ...

每次 tensor 在进行 inplace 操作时,变量 _version 就会加1,其初始值为0。在正向传播过程中,求导系统记录的 b 的 version 是0,但是在进行反向传播的过程中,求导系统发现 b 的 version 变成1了,所以就会报错了。但是还有一种特殊情况不会报错,就是反向传播求导的时候如果没用到 b 的值(比如 y=x+1, y 关于 x 的导数是1,和 x 无关),自然就不会去对比 b 前后的 version 了,所以不会报错
对于 requires_grad=True 的叶子节点来说,要求更加严格了,甚至在叶子节点被使用之前修改它的值都不行

RuntimeError: leaf variable has been moved into the graph interior

上述报错信息是经过 inplace 操作把一个叶子节点变成了非叶子节点,我们知道,非叶子节点的导数在默认情况下是不会被保存的

a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
print(a, a.is_leaf)
# tensor([10.,  5.,  2.,  3.], requires_grad=True) Truea[:] = 0
print(a, a.is_leaf)
# tensor([0., 0., 0., 0.], grad_fn=<CopySlices>) Falseloss = (a*a).mean()loss.backward()
# RuntimeError: leaf variable has been moved into the graph interior

我们观察到,在对变量 a 进行重新赋值后,a 变成了通过复制操作<CopySlices>生成的张量,它不再是叶子节点。原本应该保留梯度值的变量,现在却成为了梯度会被自动释放的中间变量
另外一种情况:

a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
a.add_(10.) # 或者 a += 10.
# RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

这个情况更为严重:在你调用 backward 之前,只要对需要求导的叶子张量执行了这些操作,就会立即报错。那么,是否意味着一旦叶子节点被初始化赋值后,就不能再修改它们的值呢?如果在某些情况下我们确实需要重新对叶子变量赋值,该怎么办呢?

# 方法一
a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
print(a, a.is_leaf, id(a))
# tensor([10.,  5.,  2.,  3.], requires_grad=True) True 2501274822696a.data.fill_(10.)
# 或者 a.detach().fill_(10.)print(a, a.is_leaf, id(a))
# tensor([10., 10., 10., 10.], requires_grad=True) True 2501274822696loss = (a*a).mean()
loss.backward()print(a.grad)
# tensor([5., 5., 5., 5.])# 方法二
a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
print(a, a.is_leaf)
# tensor([10.,  5.,  2.,  3.], requires_grad=True) Truewith torch.no_grad():a[:] = 10.
print(a, a.is_leaf)
# tensor([10., 10., 10., 10.], requires_grad=True) Trueloss = (a*a).mean()
loss.backward()print(a.grad)
# tensor([5., 5., 5., 5.])

我们需要注意的是,要在变量被使用之前修改,不然等计算完之后再修改,还会造成求导上的问题
为什么 PyTorch 的求导不支持绝大部分 inplace 操作呢?从上边我们也看出来了,因为真的很 tricky。比如有的时候在一个变量已经参与了正向传播的计算,之后它的值被修改了,在做反向传播的时候如果还需要这个变量的值的话,我们肯定不能用那个后来修改的值吧,但没修改之前的原始值已经被释放掉了,我们怎么办?一种可行的办法就是我们在 Function 做 forward 的时候每次都开辟一片空间储存当时输入变量的值,这样无论之后它们怎么修改,都不会影响了,反正我们有备份在存着。但这样有什么问题?这样会导致内存(或显存)使用量大大增加。因为我们不确定哪个变量可能之后会做 inplace 操作,所以我们每个变量在做完 forward 之后都要储存一个备份,成本太高了
PyTorch 不推荐使用 inplace 操作,当求导过程中发现有 inplace 操作影响求导正确性的时候,会采用报错的方式提醒。但这句话反过来说就是,因为只要有 inplace 操作不当就会报错,所以如果我们在程序中使用了 inplace 操作却没报错,那么说明我们最后求导的结果是正确的

PyTorch的两大特点

PyTorch 的两大特点是动态图和eager execution,这使得它的使用非常流畅,几乎和编写 Python 程序一样舒适,同时调试过程也极为方便;同时PyTorch 十分注重占用内存(或显存)大小,没有用的空间释放很及时,可以很有效地利用有限的内存

动态图

PyTorch 使用的是动态图(Dynamic Computational Graphs)的方式,而 TensorFlow 使用的是静态图(Static Computational Graphs)

  • 动态图:每次当我们搭建完一个计算图,然后在反向传播结束之后,整个计算图就在内存中被释放了。如果想再次使用的话,必须从头再搭一遍
  • 静态图:每次都先设计好计算图,需要的时候实例化这个图,然后送入各种输入,重复使用,只有当会话结束的时候创建的图才会被释放
# 这是一个关于 PyTorch 是动态图的例子:
a = torch.tensor([3.0, 1.0], requires_grad=True)
b = a * a
loss = b.mean()
loss.backward() # 正常
loss.backward() # RuntimeError# 第二次:从头再来一遍
a = torch.tensor([3.0, 1.0], requires_grad=True)
b = a * a
loss = b.mean()
loss.backward() # 正常

理论上,静态图的效率比动态图的效率高。因为静态图只需要一次构建便可重复使用

eager execution

当遇到 tensor 计算的时候,马上就回去执行计算,实际上 PyTorch 根本不会去构建正向计算图,而是遇到操作就执行。真正意义上的正向计算图是把所有的操作都添加完,构建好了之后,再运行神经网络的正向传播

PyTorch中的Variable

tensor是硬币的话,那Variable就是钱包,它记录着里面的钱的多少,和钱的流向
在这里插入图片描述
torch0.4版本以后 torch.tensor() 就可以搞定所有

《PyTorch核心概念:从梯度、计算图到连续性的全面解析(一)》
《PyTorch核心概念:从梯度、计算图到连续性的全面解析(三)》

参考文献

1、PyTorch 的 Autograd
2、Pytorch入坑二:autograd 及Variable

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/6670.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

dayseven-因果分析-图模型与结构因果模型

在数学上&#xff0c;​“图”(graph)是顶点&#xff08;vertex&#xff0c;也可以称为节点&#xff09;和边(edge)的集合&#xff0c;表示为图G(V,E)&#xff0c;其中V是节点的集合&#xff0c;E是边的集合&#xff0c;图中的节点之间通过边相连&#xff08;也可以不相连&…

今天强的可怕,AI文风写作再也不用写指令了

AI写作最有用的事情之一就是捕捉特定的写作风格&#xff0c;市面上写作工具模仿文风需要下达复杂的prompt&#xff0c;经过一大段精细的微调才能实现&#xff01; 而现在文思助手只要一个按钮就能输出一篇文风相似的文章&#xff01;超级简单&#xff0c;你再也不用为一大段一大…

Vue2中使用firefox的pdfjs进行文件文件流预览

文章目录 1.使用场景2. 使用方式1. npm 包下载,[点击查看](https://www.npmjs.com/package/pdfjs-dist)2. 官网下载1. 放到public文件夹下面2. 官网下载地址[点我,进入官网](https://github.com/mozilla/pdf.js/tags?afterv3.3.122) 3. 代码演示4. 图片预览5. 如果遇到跨域或者…

哪些因素会影响 DC/DC 转换电路快速测试的性能?-纳米软件

DC/DC 转换电路在现代电子设备中起着至关重要的作用&#xff0c;其性能的快速准确测试对于确保电子系统的可靠性和稳定性至关重要。然而&#xff0c;有许多因素会影响 DC/DC 转换电路快速测试的性能。 电路复杂性和参数多样性 单片 DC/DC 转换器由于功能模块和参数复杂性&…

解线性方程组(二)

实验类型&#xff1a;●验证性实验 ○综合性实验 ○设计性实验 实验目的&#xff1a;进一步熟练掌握用Jacobi迭代法和Gauss-Seidel法解线性方程组的算法&#xff0c;提高编程能力和解算线性方程组问题的实践技能。 实验内容&#xff1a; 1)取初值性x(0)(0,0,0,0)T, 精度要求ε…

跨境电商营销:Pinterest的5个便捷营销工具

Pinterest是消费者寻找创意灵感的首选平台之一&#xff0c;同时&#xff0c;根据Global Web Index的调查数据&#xff0c;人们使用Pinterest的首要原因是寻找新产品和品牌&#xff0c;这意味着用户在使用Pinterest时已经有消费意愿和倾向。 因此&#xff0c;让更多目标受众注意…

JAVA基础:多重循环、方法、递归 (习题笔记)

一&#xff0c;编码题 1.打印九九乘法表 import java.util.*;public class PanTi {public static void main(String[] args) {Scanner input new Scanner(System.in);for (int i 0; i < 9; i) {//i控制行数/* System.out.println("。\t。\t。\t。\t。\t。\t。\t。\…

小林渗透入门:burpsuite+proxifier抓取小程序流量

目录 前提&#xff1a; 代理&#xff1a; proxifier&#xff1a; 步骤&#xff1a; bp证书安装 bp设置代理端口&#xff1a; proxifier设置规则&#xff1a; proxifier应用规则&#xff1a; 结果&#xff1a; 前提&#xff1a; 在介绍这两个工具具体实现方法之前&#xff0…

[笔记] Centos7 安装 Docker 和 Docker Compose 及 Docker 命令大全

Docker 和 Docker Compose 是相辅相成的工具&#xff0c;它们共同提供了一个强大的容器化解决方案。Docker 提供了容器化的基础功能&#xff0c;而 Docker Compose 则提供了更高级的编排和管理能力&#xff0c;使得部署和管理多个容器变得更加容易和高效。 Docker&#xff1a;…

el-message 同时弹出多个【改写el-message】

因为服务断开了 但是拦截器里对每个失败的接口都做了message弹出&#xff0c;因此改写el-message逻辑&#xff0c;仅展示一个同等类型的message窗体 1. 新建 /utils/rewriteElMessage.js /*** Event 解决 el-message 同类型重复打开的问题* description:* author: mhf* time:…

SSM宿舍管理系统-计算机毕业设计源码03732

目 录 1 绪论 1.1研究背景 1.2开发现状 1.3研究内容 1.4论文结构与章节安排 2 宿舍管理系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 2.5本章小结 3 …

淘宝商品描述,一键“爬”回家 —— Java爬虫的奇妙冒险

引言&#xff1a; 在这个信息爆炸的时代&#xff0c;我们每天都在被各种商品信息轰炸。你是否曾想过&#xff0c;如何能快速、准确地获取淘宝商品的描述信息&#xff1f;今天&#xff0c;就让我们一起开启一段Java爬虫的奇妙冒险&#xff0c;探索如何通过代码一键“爬”取淘宝…

线性代数:Matrix2x2和Matrix3x3

今天整理自己的框架代码&#xff0c;将Matrix2x2和Matrix3x3给扩展了一下&#xff0c;发现网上unity数学计算相关挺少的&#xff0c;所以记录一下。 首先扩展Matrix2x2&#xff1a; using System.Collections; using System.Collections.Generic; using Unity.Mathemati…

windows在两台机器上测试 MySQL 集群实现实时备份

在两台机器上测试 MySQL 集群实现实时备份的基本步骤&#xff1a; 一、环境准备 机器配置 确保两台机器&#xff08;假设为服务器 A 和服务器 B&#xff09;能够互相通信&#xff0c;例如它们在同一个局域网内&#xff0c;并且开放了 MySQL 通信所需的端口&#xff08;默认是 …

【stm32】RTC时钟的介绍与使用

RTC时钟的介绍与使用 一、时间戳1、Unix时间戳2、UTC/GMT3、时间戳转换 二、BKP简介及代码编写1、BKP简介2、BKP基本结构3、BKP库函数介绍&#xff1a;4、程序编写&#xff1a; 三、RTC简介及代码编写1、RTC简介2、RTC框图2、RTC基本结构3、RTC相关库函数介绍&#xff1a;4、程…

界面控件DevExpress JS ASP.NET Core v24.1亮点 - 支持Angular 18

DevExtreme拥有高性能的HTML5 / JavaScript小部件集合&#xff0c;使您可以利用现代Web开发堆栈&#xff08;包括React&#xff0c;Angular&#xff0c;ASP.NET Core&#xff0c;jQuery&#xff0c;Knockout等&#xff09;构建交互式的Web应用程序。从Angular和Reac&#xff0c…

如何检索 LINEMOD 数据集的相机内参

简介 BOP (Benchmark for 6D Object Pose Estimation) 是一个专为6D物体姿态估计而设计的基准测试平台。它为研究人员提供了多种数据集&#xff0c;以帮助评估和比较物体识别和姿态估计算法的性能。官方网站是 BOP&#xff0c;你可以在这里找到丰富的资源和信息。 检索 LINEM…

机器学习与数据挖掘_使用梯度下降法训练线性回归模型

目录 实验内容 实验步骤 1. 导入必要的库 2. 加载数据并绘制散点图 3. 设置模型的超参数 4. 实现梯度下降算法 5. 打印训练后的参数和损失值 6. 绘制损失函数随迭代次数的变化图 7. 绘制线性回归拟合曲线 8. 基于训练好的模型进行新样本预测 实验代码 实验结果 实验…

web——sqliabs靶场——第一关

今天开始搞这个靶场&#xff0c;从小白开始一点点学习,加油&#xff01;&#xff01;&#xff01;&#xff01; 1.搭建靶场 注意点&#xff1a;1.php的版本问题&#xff0c;要用老版本 2.小p要先改数据库的密码&#xff0c;否则一直显示链接不上数据库 2.第一道题&#xff0…