【Pytorch笔记】4.梯度计算

深度之眼官方账号 - 01-04-mp4-计算图与动态图机制

前置知识:计算图
可以参考我的笔记:
【学习笔记】计算机视觉与深度学习(2.全连接神经网络)

计算图

在这里插入图片描述
以这棵计算图为例。这个计算图中,叶子节点为x和w。

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
print(w.grad)print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

输出:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) None None None

由此可见,非叶子节点在最后不会被保留梯度。这是出于节省空间的需要而这样设计的。实际的计算图会非常大,如果每个节点都保留梯度,会占用非常大的存储空间,而这些节点的梯度对于我们学习并没有什么帮助。

如果非要看他们的梯度,可以这样操作:在a = torch.add(w, x)的后面加上一句a.retain_grad(),这样a的梯度就会被存储起来。
输出会变成:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) tensor([2.]) None None

对于节点,还可以看这些节点进行的运算。grad_fn,gradient function的缩写,表示这个节点的tensor是什么运算产生的。加一句:

print("gradient function:\n", w.grad_fn, '\n', x.grad_fn, '\n', a.grad_fn, '\n', b.grad_fn, '\n', y.grad_fn)

会输出

gradient function:NoneNone<AddBackward0 object at 0x000001B1DA3651C0><AddBackward0 object at 0x000001B1DA3651F0><MulBackward0 object at 0x000001B1DA3515B0>

retain_graph

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
a.retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
y.backward()

连续两次调用backward()方法,会报这样的错误:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因是我们进行第一次backward()后,计算图就被自动释放掉了,进行第二次backward()时,没有计算图可以计算梯度,于是报错。

解决方案:backward内部添加一个参数:retain_graph=True,意思是计算完梯度后保留计算图。

# 调用backward()方法,开始反向求梯度
y.backward(retain_graph=True)
y.backward()

这样就不会报错了。

gradient

当计算图末部的节点有1个以上时,有时我们会希望他们之间的梯度有一个权重关系。这时就会用上gradient

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)# 不难看出,y0和y1是两个互不干扰的末部节点
y0 = torch.mul(a, b)
y1 = torch.add(a, b)# 将两个末部节点打包起来
loss = torch.cat([y0, y1], dim=0)
grad_tensors = torch.tensor([1., 2.])# 将grad_tensors中的内容作为权重,变成y0+2y1
loss.backward(gradient=grad_tensors)print(w.grad)

输出

tensor([9.])

如果把grad_tensors改成:

grad_tensors = torch.tensor([1., 3.])

输出变成:

tensor([11.])

torch.autograd.grad()

除了加减乘除法,我们还可以对torch进行求导操作。求的是 d ( o u t p u t s ) d ( i n p u t s ) \frac{d(outputs)}{d(inputs)} d(inputs)d(outputs)

torch.autograd.grad(outputs,inputs,grad_outputs=None,retain_graph=None,create_graph=False)

outputs和inputs已在上述定义中给出;
grad_outputs:多梯度权重;
retain_graph:保留计算图;
create_graph:创建计算图。

import torch# y = x ** 2
x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)# grad_1 = dy / dx = 2x = 6
grad_1 = torch.autograd.grad(y, x, create_graph=True)
print(grad_1)# grad_2 = d(dy / dx) / dx = 2
grad_2 = torch.autograd.grad(grad_1, x)
print(grad_2)

输出

(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

autograd注意事项

1.梯度不会自动清零

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)for i in range(4):a = torch.add(w, x)b = torch.mul(w, x)y = torch.mul(a, b)y.backward()print("w's grad: ", w.grad)# w.grad.zero_()

输出:

w's grad:  tensor([8.])
w's grad:  tensor([16.])
w's grad:  tensor([24.])
w's grad:  tensor([32.])

由此可以看出,在不加上注释掉的那一行时,梯度在w处是不断累积的。而如果我们把print后面的那句w.grad.zero_()加上,输出就会变成:

w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])

w.grad.zero_()的意思就是把w处积累的梯度清零。

2.依赖于叶子节点的节点,requires_grad默认为True

可以从上面的代码中发现,我们只有在定义w和x两个tensor时,设置requires_grad为True。这个参数在定义tensor时默认为False。后面我们的a、b、y都没有设置这个参数。

如果我们定义w和x的时候不加上requires_grad=True,那么y.backward()这一步就会报错,因为我们的预设,这两个tensor不需要梯度,于是就无法求梯度。而w和x是我们计算图上的叶子节点,所以必须加上requires_grad=True。

而后面通过w和x延伸定义出的a、b、y,由于依赖的w、x的requires_grad是True,那么a、b、y的这个参数也被默认设置为了True,不需要我们手动添加。

3.叶子节点不可执行in-place操作

计算图上叶子节点处的tensor不能进行原地修改。

什么是in-place操作?
t = torch.tensor([1., 2.])
t.add_(3.)
print(t)

输出

tensor([4., 5.])

torch.Tensor.add_就是torch.add的in-place版本。所谓in-place,就是在tensor上进行原地修改。大部分的torch.tensor的运算,名字后面加一个下划线,就变成inplace操作了。

再比如求绝对值:

t = torch.tensor([-1., -2.])
t.abs_()
print(t)

输出

tensor([1., 2.])

知道什么是in-place操作后,我们尝试一下在requires_grad=True的叶子节点上原地修改,代码如下:

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.mul(w, x)
y = torch.mul(a, b)w.add_(1)y.backward()

报错信息:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/148021.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

使用关键字interface来声明使用接口-PHP8知识详解

继承特性简化了对象、类的创建&#xff0c;增加了代码的可重用性。但是php8只支持单继承&#xff0c;如果想实现多继承&#xff0c;就需要使用接口。PHP8可以实现多个接口。 接口类通过关键字interface来声明&#xff0c;接口中不能声明变量&#xff0c;只能使用关键字const声明…

机器人中的数值优化|【六】线性共轭梯度法,牛顿共轭梯度法

机器人中的数值优化|【六】线性共轭梯度法&#xff0c;牛顿共轭梯度法 往期回顾 机器人中的数值优化|【一】数值优化基础 机器人中的数值优化|【二】最速下降法&#xff0c;可行牛顿法的python实现&#xff0c;以Rosenbrock function为例 机器人中的数值优化|【三】无约束优化…

stm32 - 中断

stm32 - 中断 概念中断向量表NVIC 嵌套中断向量控制器优先级 中断EXTI概念基本结构例子- 对射式红外传感器计次例子 - 旋转编码器 概念 stm32 支持的中断资源&#xff08;都属于外设&#xff09; EXTITIMADCUSARtSPII2C stm32支持的中断 内核中断 外设中断 中断通道与优先级 一…

C# 读取Execl文件3种方法

方法 1&#xff0c;使用OLEDB可以对excel文件进行读取 1.1C#提供的数据连接有哪些 对于不同的.net数据提供者&#xff0c;ADO.NET采用不同的Connection对象连接数据库。这些Connection对我们屏蔽了具体的实现细节&#xff0c;并提供了一种统一的实现方法。 Connection类有四…

【Linux】线程池

目录 一、线程池1.什么是线程池2.线程池图解3.实现代码 二、单例模式1.单例模式的概念2.饿汉方式实现单例模式3.懒汉方式实现单例模式4.懒汉方式实现单例模式的线程池 一、线程池 1.什么是线程池 线程虽然比进程轻量了很多&#xff0c;但是每创建一个线程时&#xff0c;需要向…

UCOS的任务创建和删除

一、任务创建和删除的API函数 1、任务创建和删除本质就是调用uC/OS的函数 API函数 描述 OSTaskCreate() 创建任务 OSTaskDel() 删除任务 注意&#xff1a; 1&#xff0c;使用OSTaskCreate() 创建任务&#xff0c;任务的任务控制块以及任务栈空间所需的内存&#xff0c…

算法——买卖股票问题

309. 买卖股票的最佳时机含冷冻期 - 力扣&#xff08;LeetCode&#xff09; 一、 究其就是个动态规划的问题 算法实现图 初始化 由于有三个阶段&#xff0c;买入&#xff0c;可交易&#xff0c;冷冻期&#xff0c;那么用dp表表示现在为止的最大利润&#xff0c;则有 dp[0][…

asp.net core 远程调试

大概说下过程&#xff1a; 1、站点发布使用Debug模式 2、拷贝到远程服务器&#xff0c;以及iis创建站点。 3、本地的VS2022的安装目录&#xff1a;C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\IDE下找Remote Debugger 你的服务器是64位就拷贝x64的目…

详解Linux的系统调用fork()函数

在Linux系统中&#xff0c;fork()是一个非常重要的系统调用&#xff0c;它的作用是创建一个新的进程。具体来说&#xff0c;fork()函数会在当前进程的地址空间中复制一份子进程&#xff0c;并且这个子进程几乎完全与父进程相同&#xff0c;包括进程代码、数据、堆栈以及打开的文…

WebSocket实战之四WSS配置

一、前言 上一篇文章WebSocket实战之三遇上PAC &#xff0c;碰到的问题只能上安全的WebSocket&#xff08;WSS&#xff09;才能解决&#xff0c;配置证书还是挺麻烦的&#xff0c;主要是每年都需要重新更新证书&#xff0c;我配置过的证书最长有效期也只有两年&#xff0c;搞不…

ElasticSearch第四讲:ES详解:ElasticSearch和Kibana安装

ElasticSearch第四讲&#xff1a;ES详解&#xff1a;ElasticSearch和Kibana安装 本文是ElasticSearch第四讲&#xff1a;ElasticSearch和Kibana安装&#xff0c;主要介绍ElasticSearch和Kibana的安装。了解完ElasticSearch基础和Elastic Stack生态后&#xff0c;我们便可以开始…

ctfshow—1024系列练习

1024 柏拉图 有点像rce远程执行&#xff0c;有四个按钮&#xff0c;分别对应四份php文件&#xff0c;开始搞一下。一开始&#xff0c;先要试探出 文件上传到哪里&#xff1f; 怎么读取上传的文件&#xff1f; 第一步&#xff1a;试探上传文件位置 直接用burp抓包&#xff0c;…

力扣练习——链表在线OJ

目录 提示&#xff1a; 一、移除链表元素 题目&#xff1a; 解答&#xff1a; 二、反转链表 题目&#xff1a; 解答&#xff1a; 三、找到链表的中间结点 题目&#xff1a; 解答&#xff1a; 四、合并两个有序链表&#xff08;经典&#xff09; 题目&#xff1a; 解…

【数据结构---排序】很详细的哦

本篇文章介绍数据结构中的几种排序哦~ 文章目录 前言一、排序是什么&#xff1f;二、排序的分类 1.直接插入排序2.希尔排序3.选择排序4.冒泡排序5.快速排序6.归并排序总结 前言 排序在我们的生活当中无处不在&#xff0c;当然&#xff0c;它在计算机程序当中也是一种很重要的操…

聊聊常见的IO模型 BIO/NIO/AIO 、DIO、多路复用等IO模型

聊聊常见的IO模型 BIO/NIO/AIO/DIO、IO多路复用等IO模型 文章目录 一、前言1. 什么是IO模型2. 为什么需要IO模型 二、常见的IO模型1. 同步阻塞IO&#xff08;Blocking IO&#xff0c;BIO&#xff09;2. 同步非阻塞IO&#xff08;Non-blocking IO&#xff0c;NIO&#xff09;3.…

排序算法之【希尔排序】

&#x1f4d9;作者简介&#xff1a; 清水加冰&#xff0c;目前大二在读&#xff0c;正在学习C/C、Python、操作系统、数据库等。 &#x1f4d8;相关专栏&#xff1a;C语言初阶、C语言进阶、C语言刷题训练营、数据结构刷题训练营、有感兴趣的可以看一看。 欢迎点赞 &#x1f44d…

八大排序源码(含优化)

文章目录 1、直接插入排序2、希尔排序3、选择排序4、冒泡排序5、堆排序6、快速排序快速排序递归实现霍尔法挖坑法前后指针法快速排序小区间优化 快速排序非递归实现 7、归并排序归并排序递归实现归并排序非递归 8、计数排序 大家好&#xff0c;我是纪宁&#xff0c;这篇文章是关…

java Spring Boot 自动启动热部署 (别再改点东西就要重启啦)

上文 java Spring Boot 手动启动热部署 我们实现了一个手动热部署的代码 但其实很多人会觉得 这叫说明热开发呀 这么捞 写完还要手动去点一下 很不友好 其实我们开发人员肯定是希望重启这种事不需要自己手动去做 那么 当然可以 我们就让它自己去做 Build Project 这个操作 我们…

Linux性能优化--性能工具:系统内存

3.0.概述 本章概述了系统级的Linux内存性能工具。本章将讨论这些工具可以测量的内存统计信息&#xff0c;以及如何使用各种工具收集这些统计结果。阅读本章后&#xff0c;你将能够&#xff1a; 理解系统级性能的基本指标&#xff0c;包括内存的使用情况。明白哪些工具可以检索…

Java21 新特性

文章目录 1. 概述2. JDK21 安装与配置3. 新特性3.1 switch模式匹配3.2 字符串模板3.3 顺序集合3.4 记录模式&#xff08;Record Patterns&#xff09;3.5 未命名类和实例的main方法&#xff08;预览版&#xff09;3.6 虚拟线程 1. 概述 2023年9月19日 &#xff0c;Oracle 发布了…