pytorch的动态计算图机制

pytorch的动态计算图机制

一,动态计算图简介

在这里插入图片描述

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。

Pytorch中的计算图是动态图。这里的动态主要有两重含义。

第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

1,计算图的正向传播是立即执行的。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))print(loss.data)
print(Y_hat.data)
tensor(17.8969)
tensor([[3.2613],[4.7322],[4.5037],[7.5899],[7.0973],[1.3287],[6.1473],[1.3492],[1.3911],[1.2150]])

2,计算图在反向传播后立即销毁。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) #loss.backward() #如果再次执行反向传播将报错

二,计算图中的Function

计算图中的另外一种节点是Function, 实际上就是 Pytorch中各种对张量操作的函数。

这些Function和我们Python中的函数有一个较大的区别,那就是它同时包括正向计算逻辑和反向传播的逻辑。

我们可以通过继承torch.autograd.Function来创建这种支持反向传播的Function

class MyReLU(torch.autograd.Function):#正向传播逻辑,可以用ctx存储一些值,供反向传播使用。@staticmethoddef forward(ctx, input):ctx.save_for_backward(input)return input.clamp(min=0)#反向传播逻辑@staticmethoddef backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.tensor([[-1.0,-1.0],[1.0,1.0]])
Y = torch.tensor([[2.0,3.0]])relu = MyReLU.apply # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b)
loss = torch.mean(torch.pow(Y_hat-Y,2))loss.backward()print(w.grad)
print(b.grad)
tensor([[4.5000, 4.5000]])
tensor([[4.5000]])
# Y_hat的梯度函数即是我们自己所定义的 MyReLU.backwardprint(Y_hat.grad_fn)
<torch.autograd.function.MyReLUBackward object at 0x1205a46c8>

三,计算图与反向传播

了解了Function的功能,我们可以简单地理解一下反向传播的原理和过程。理解该部分原理需要一些高等数学中求导链式法则的基础知识。

import torch x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2loss.backward()

loss.backward()语句调用后,依次发生以下计算过程。

1,loss自己的grad梯度赋值为1,即对自身的梯度为1。

2,loss根据其自身梯度以及关联的backward方法,计算出其对应的自变量即y1和y2的梯度,将该值赋值到y1.grad和y2.grad。

3,y2和y1根据其自身梯度以及关联的backward方法, 分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加。

(注意,1,2,3步骤的求梯度顺序和对多个梯度值的累加规则恰好是求导链式法则的程序表述)

正因为求导链式法则衍生的梯度累加规则,张量的grad梯度不会自动清零,在需要的时候需要手动置零。

四,叶子节点和非叶子节点

执行下面代码,我们会发现 loss.grad并不是我们期望的1,而是 None。

类似地 y1.grad 以及 y2.grad也是 None.

这是为什么呢?这是由于它们不是叶子节点张量。

在反向传播过程中,只有 is_leaf=True 的叶子节点,需要求导的张量的导数结果才会被最后保留下来。

那么什么是叶子节点张量呢?叶子节点张量需要满足两个条件。

1,叶子节点张量是由用户直接创建的张量,而非由某个Function通过计算得到的张量。

2,叶子节点张量的 requires_grad属性必须为True.

Pytorch设计这样的规则主要是为了节约内存或者显存空间,因为几乎所有的时候,用户只会关心他自己直接创建的张量的梯度。

所有依赖于叶子节点张量的张量, 其requires_grad 属性必定是True的,但其梯度值只在计算过程中被用到,不会最终存储到grad属性中。

如果需要保留中间计算结果的梯度到grad属性中,可以使用 retain_grad方法。
如果仅仅是为了调试代码查看梯度值,可以利用register_hook打印日志。

import torch x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2loss.backward()
print("loss.grad:", loss.grad)
print("y1.grad:", y1.grad)
print("y2.grad:", y2.grad)
print(x.grad)
loss.grad: None
y1.grad: None
y2.grad: None
tensor(4.)
print(x.is_leaf)
print(y1.is_leaf)
print(y2.is_leaf)
print(loss.is_leaf)
True
False
False
False

利用retain_grad可以保留非叶子节点的梯度值,利用register_hook可以查看非叶子节点的梯度值。

import torch #正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)
y2 grad:  tensor(4.)
y1 grad:  tensor(-4.)
loss.grad: tensor(1.)
x.grad: tensor(4.)

五,计算图在TensorBoard中的可视化

可以利用 torch.utils.tensorboard 将计算图导出到 TensorBoard进行可视化。

from torch import nn 
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.w = nn.Parameter(torch.randn(2,1))self.b = nn.Parameter(torch.zeros(1,1))def forward(self, x):y = x@self.w + self.breturn ynet = Net()
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('../data/tensorboard')
writer.add_graph(net,input_to_model = torch.rand(10,2))
writer.close()
%load_ext tensorboard
#%tensorboard --logdir ../data/tensorboard
from tensorboard import notebook
notebook.list() 
#在tensorboard中查看模型
notebook.start("--logdir ../data/tensorboard")

在这里插入图片描述


Reference:

https://jackiexiao.github.io/eat_pytorch_in_20_days/2.%E6%A0%B8%E5%BF%83%E6%A6%82%E5%BF%B5/2-3%2C%E5%8A%A8%E6%80%81%E8%AE%A1%E7%AE%97%E5%9B%BE/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147613.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Swin Transformer—使用平移窗口的分层视觉转换器结构

Swin Transformer解读 论文题目&#xff1a;Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. 官方代码地址&#xff1a;https://github.com/microsoft/Swin-Transformer. 引言与概括 ICCV2021的最佳论文作者是来自微软亚洲研究院。 SwinTransforme…

基础实践:使用JQuery Ajax调用Servlet

前言 本博客介绍最简单的JQuery&#xff08;原生JS的封装库&#xff09;使用Ajax发送请求&#xff0c;并通过对应的servlet响应数据&#xff0c;并在页面显示&#xff0c;并且servlet响应的数据来自MySQL数据库。 实现需求&#xff1a;在前端页面的输入框中输入要注册的用户名&…

依赖库查看工具Dependencies

依赖库查看工具&#xff1a;Dependencies Dependencies 是一款 Windows 平台下的静态分析工具&#xff0c;用来分析可执行文件&#xff08;EXE、DLL 等&#xff09;所依赖的动态链接库&#xff08;DLL&#xff09;。它可以帮助开发者和系统管理员快速查找程序在运行时可能缺少的…

【机器学习】--- 决策树与随机森林

文章目录 决策树与随机森林的改进&#xff1a;全面解析与深度优化目录1. 决策树的基本原理2. 决策树的缺陷及改进方法2.1 剪枝技术2.2 树的深度控制2.3 特征选择的优化 3. 随机森林的基本原理4. 随机森林的缺陷及改进方法4.1 特征重要性改进4.2 树的集成方法优化4.3 随机森林的…

论文浅尝 | KAM-CoT: 利用知识图谱进行知识增强的多模态链式推理(AAAI2024)

笔记整理&#xff1a;沈小力&#xff0c;东南大学硕士&#xff0c;研究方向为多模态大预言模型、知识图谱 论文链接&#xff1a;https://arxiv.org/abs/2401.12863 发表会议&#xff1a;AAAI2024 1. 动机 本文探索了知识图谱在扩展大语言模型的多模态能力的效果&#xff0c;提出…

在jupyter notebook中取消代理服务器的解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

使用GPU 加速 Polars:高效解决大规模数据问题

Polars 最近新开发了一个可以支持 GPU 加速计算的执行引擎。这个引擎可以对超过 100GB 的数据进行交互式操作能。本文将详细讨论 Polars 中DF的概念、GPU 加速如何与 Polars DF协同工作&#xff0c;以及使用新的 CUDA 驱动执行引擎可能带来的性能提升。 Polars 核心概念 Polar…

go libreoffice word 转pdf

一、main.go 关键代码 完整代码 package mainimport ("fmt""github.com/jmoiron/sqlx""github.com/tealeg/xlsx""log""os/exec""path/filepath" ) import _ "github.com/go-sql-driver/mysql"import &q…

多态与绑定例题

答案&#xff1a; B D C 知识点&#xff1a; 多态是相同方法不同的表现&#xff0c;分为重写和重载 重写体现在父类与子类不同表现&#xff0c;主要表现为子类重现父类的方法 重载体现在同一个类中的不同表现 绑定分为动态绑定和静态绑定 动态绑定是在运行时 静态绑定是…

java23发布啦

2024年9月java23发布啦&#xff01;&#xff01;! JDK 23 提供了12 项增强功能&#xff0c;这些功能足以保证其自己的JDK 增强提案 - JEP &#xff0c;其中包括 8 项预览功能和 1 项孵化器功能。它们涵盖了对 Java 语言、API、性能和 JDK 中包含的工具的改进。除了 Java 平台上…

《独孤九剑》游戏源码(客户端+服务端+数据库+游戏全套源码)大小2.38G

《独孤九剑》游戏源码&#xff08;客户端服务端数据库游戏全套源码&#xff09;大小2.38G ​ 下载地址&#xff1a; 通过网盘分享的文件&#xff1a;【源码】《独孤九剑》游戏源码&#xff08;客户端服务端数据库游戏全套源码&#xff09;大小2.38G 链接: https://pan.baidu.co…

走在时代前沿:让ChatGPT成为你的职场超级助手

在当今快节奏的工作环境中&#xff0c;时间和效率是宝贵的资源。人工智能&#xff08;AI&#xff09;&#xff0c;尤其是自然语言处理技术的进步&#xff0c;为我们提供了强大的工具来优化工作流程。ChatGPT&#xff08;Generative Pre-trained Transformer&#xff09;就是其中…

计算机毕业设计之:基基于微信小程序的轻食减脂平台的设计与实现(源码+文档+讲解)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

PostgreSQL技术内幕10:PostgreSQL事务原理解析-日志模块介绍

文章目录 0.简介1.PG日志介绍2.事务日志介绍3.WAL分析3.1 WAL概述3.2 WAL设计考虑3.2.1 存储格式3.2.2 实现方式3.2.3 数据完整性校验3.3 check ponit 4.事务提交日志&#xff08;CLOG&#xff09;4.1 clog存储使用介绍4.2 slru缓冲池并发控制 0.简介 本文将延续上一篇文章内容…

59.【C语言】内存函数(memmove函数)

目录 2.memove函数 *简单使用 部分翻译 *模拟实现 方案1 方案2 1.有重叠 dest在src左侧 dest在src右侧 2.无重叠 代码 2.memove函数 *简单使用 memove:memory move cplusplus的介绍 点我跳转 对比第59篇的memcpy函数 对比memmcpy函数的介绍如下区别: 部分翻译 m…

金刚石切削工具学习笔记分享

CVD钻石-合成单晶钻石之一 金刚石具有极高的硬度和耐磨性、较低的摩擦系数、较高的弹性模量、较高的热导率、较低的热膨胀系数、与有色金属的亲和力较小等优点&#xff0c;是目前最硬的工具材料&#xff0c;主要分为单晶金刚石和聚晶金刚石两大类。单晶金刚石又分为天然单晶金…

常用卫星学习

文章目录 Landsat-8 Landsat-8 由一台操作陆地成像仪 &#xff08;OLI&#xff09; 和一台热红外传感器 &#xff08;TIRS&#xff09;的卫星&#xff0c;OLI 提供 9 个波段&#xff0c;覆盖 0.43–2.29 μm 的波长&#xff0c;其中全色波段&#xff08;一般指0.5μm到0.75μm左…

CentOS Stream 9部署MariaDB

1、更新系统软件包 sudo dnf update 2、安装MariaDB软件包&#xff08;替代mysql&#xff09; sudo dnf install mariadb-server 3、安装MariaDB服务 sudo systemctl enable --now mariadb 4、检查MariaDB服务状态 sudo systemctl status mariadb 5、配置MariaDB安全性 sudo my…

锐捷 睿易路由器存在RCE漏洞

漏洞描述 锐捷Ruijie路由器命令执行漏 漏洞复现 FOFA: icon_hash"-399311436" 点击左下角的“网络诊断”&#xff0c;在“Tracert检测”的“地址”框中&#xff0c;输入127.0.0.1;ls&#xff0c;接着点击“开始检测”&#xff0c;会在检测框中回显命令执行结果。…

代码编辑器 —— SourceInsight实用技巧

目 录 Source insight 重要性一、创建项目二、代码浏览三、代码同步 Source insight 重要性 Source Insight 是一款功能强大的代码编辑器&#xff0c;在软件开发中占据着重要地位。 Source Insight 能够帮助开发者更高效地解读和修改代码&#xff0c;提高开发效率和代码质量。…