用 Python 从零开始创建神经网络(八):梯度、偏导数和链式法则

梯度、偏导数和链式法则

  • 引言
  • 1. 偏导数
  • 2. 和的偏导数
  • 3. 乘法的偏导数
  • 4. Max 的偏导数
  • 5. 梯度(The Gradient)
  • 6. 链式法则(The Chain Rule)

引言

在我们继续编写我们的神经网络代码之前,最后两个需要解决的难题是梯度和偏导数的相关概念。我们到目前为止解决的导数案例都是函数中只有一个独立变量的情况——也就是说,结果完全依赖于 x x x(在我们的案例中)。然而,我们的神经网络由多个输入的神经元组成。每个输入都与相应的权重(2个参数的函数)相乘,并且与偏置(与输入数量一样多的参数,再加上一个偏置)相加。正如我们很快将详细解释的,为了学习所有输入、权重和偏置对神经元输出以及最终损失函数的影响,我们需要计算神经元和整个模型在前向传递过程中执行的每个操作的导数。为了做到这一点并得到答案,我们需要使用链式法则,我们将很快在本章中解释这一点。


1. 偏导数

偏导数用来衡量单个输入对函数输出的影响程度。计算偏导数的方法与上一章中解释的导数方法相同;我们只需要对每个独立输入重复这个过程。

函数的每个输入都在一定程度上影响这个函数的输出,即使这种影响是0。我们需要知道这些影响;这意味着我们必须分别对每个输入计算导数,以了解它们各自的影响。这就是为什么我们称这些为针对给定输入的偏导数——我们在计算与单个输入相关的导数的一部分。偏导数是一个单一的方程,而完整的多变量函数的导数由一组称为梯度的方程组成。换句话说,梯度是一个向量,其大小等于包含针对每个输入的偏导数解的输入数量。我们很快会回到梯度的话题。

为了表示偏导数,我们将使用欧拉记法。它与莱布尼茨记法非常相似,我们只需要将微分算子 d d d替换为 ∂ \partial 。虽然 d d d算子可能被用来表示多变量函数的微分,但其含义略有不同——它可以表示函数相对于给定输入的变化率,但当其他输入也可能变化时,它主要在物理学中使用。我们感兴趣的是偏导数,这是一种尝试找到给定输入对输出的影响,同时将所有其他输入视为常数的情况。我们对单个输入的影响感兴趣,因为我们的目标是在模型中更新参数。 ∂ \partial 算子明确表示了这一点——偏导数:

在这里插入图片描述


2. 和的偏导数

针对给定输入计算偏导数意味着像计算一个输入的常规导数一样进行计算,只是在计算时将其他输入视为常数。例如:

在这里插入图片描述

首先,我们应用了和的规则——和的导数是各个导数的和。然后,我们已经知道 x x x相对于 x x x的导数等于1。新的情况是 y y y相对于 x x x的导数。正如我们提到的, y y y被视为常数,因为当我们相对于 x x x求导时, y y y不会改变,而常数的导数等于0。在第二种情况中,我们相对于 y y y求导,因此将 x x x视为常数。换句话说,不管这个例子中 y y y的值如何, x x x的斜率不依赖于 y y y。不过,情况并非总是如此,我们很快就会看到。

让我们尝试另一个例子:

在这里插入图片描述

在这个例子中,我们首先应用了和的规则,然后将常数移到导数的外面,并分别针对 x x x y y y计算剩余部分。与上一章中的非多变量导数的唯一区别是“偏导”部分,这意味着我们分别对每个变量进行求导。除此之外,这里没有什么新内容。

让我们尝试一些看似更复杂的内容:

在这里插入图片描述

非常简单——我们不断地重复应用相同的规则,并且我们没有在这个例子中添加任何新的计算或规则。


3. 乘法的偏导数

在继续之前,我们先来介绍一下乘法运算的偏导数:

在这里插入图片描述

我们已经提到过,我们需要将其他独立变量视为常数,并且我们还学习了可以将常数移到导数的外面。这正是我们解决乘法偏导数计算的方法——我们将其他变量视为常数,如同数字一样,并将它们移到导数外面。结果显示,当我们对 x x x求导时, y y y被视为常数,结果等于 y y y乘以 x x x x x x的导数,即1。整个导数的结果就是 y y y。这个例子背后的直觉是,当计算关于 x x x的偏导数时, x x x的每增加1,函数的输出就增加 y y y。例如,如果 y = 3 y=3 y=3 x = 1 x=1 x=1,结果是 1 ⋅ 3 = 3 1\cdot 3 = 3 13=3。当我们将 x x x增加1,使 y = 3 y=3 y=3 x = 2 x=2 x=2时,结果是 2 ⋅ 3 = 6 2\cdot 3 = 6 23=6。我们将 x x x增加了1,结果增加了3,即增加了 y y y。这就是这个函数关于 x x x的偏导数告诉我们的内容。
让我们引入第三个输入变量,并为另一个例子添加变量的乘法:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

这里唯一的新操作,正如前面提到的,是将我们求导时不涉及的其他变量移出导数。这个例子中的结果看起来更复杂,但这只是因为其中包含了其他变量——在求导过程中被视为常数的变量。导数的方程虽然较长,但并不一定更复杂。

学习偏导数的原因是我们很快将计算多变量函数的偏导数,其中一个例子是神经元。从代码的角度和更具体地说,密集层类的前向方法来看,我们传入一个单一变量——输入数组,其中包含一批样本或前一层的输出。从数学的角度来看,这个单一变量(一个数组)的每个值都是一个单独的输入——它包含的输入数量与输入数组中的数据量一样多。例如,如果我们向神经元传递一个包含4个值的向量,在代码中它是一个单一变量,但在方程中它是4个单独的输入。这形成了一个接受多个输入的函数。为了了解每个输入对函数输出的影响,我们需要计算这个函数关于每个输入的偏导数,这将在下一章中详细解释。


4. Max 的偏导数

导数和偏导数不仅限于加法和乘法运算或常数。我们需要为前向传递中使用的其他函数推导它们,其中一个是 max() 函数的导数:

在这里插入图片描述

max函数返回最大的输入值。我们知道 x x x相对于 x x x的导数等于1,因此如果 x x x大于 y y y,这个函数相对于 x x x的导数就等于1,因为函数将返回 x x x。在另一种情况下,如果 y y y大于 x x x,并且 y y y将被返回,那么max()函数相对于 x x x的导数等于0——我们将 y y y视为常数, y y y相对于 x x x的导数等于0。我们可以将其表示为 1 ( x > y ) 1(x > y) 1(x>y),这意味着如果条件满足则为1,否则为0。我们也可以计算max()相对于 y y y的偏导数。

max()函数的导数的一个特殊情况是当我们只有一个变量参数,而另一个参数总是恒定为0时。这意味着我们希望返回更大的值——0或输入值,实际上是在正方向上将输入值限制在0。当我们计算ReLU激活函数的导数时,处理这种情况将非常有用,因为该激活函数定义为 m a x ( x , 0 ) max(x, 0) max(x,0)

在这里插入图片描述


5. 梯度(The Gradient)

正如我们在本章开始时提到的,梯度是一个向量,由一个函数的所有偏导数组成,每个偏导数都是针对每个输入变量计算的。

让我们回顾一下我们之前计算过的求和操作的一个偏导数:

在这里插入图片描述

如果我们计算所有的偏导数,我们就可以形成函数的梯度。使用不同的符号,它看起来如下:

在这里插入图片描述

这就是我们需要了解的关于梯度的所有信息——它是一个向量,包含了函数的所有可能的偏导数,我们使用∇(nabla)符号来表示它,这个符号看起来像一个倒置的delta符号。

我们将使用单参数函数的导数和多变量函数的梯度来执行梯度下降,使用链式法则,换句话说,来执行反向传递,这是模型训练的一部分。我们将如何具体做到这一点将是下一章的主题。


6. 链式法则(The Chain Rule)

在前向传递过程中,我们将数据通过神经元,然后通过激活函数,再通过下一层的神经元,然后通过另一个激活函数,依此类推。我们用输入参数调用一个函数,取得输出,并将该输出作为另一个函数的输入。以这个简单的例子,让我们考虑两个函数: f f f g g g

在这里插入图片描述

x x x 是输入数据, z z z 是函数 f f f的输出,但也是函数 g g g的输入, y y y 是函数 g g g的输出。我们可以将相同的计算写为:

在这里插入图片描述

在这种形式中,我们没有使用中间变量 z z z,显示函数 g g g直接将函数 f f f的输出作为输入。这与上面的两个方程式没有太大差别,但显示了这样链接的函数的一个重要特性——既然 x x x是函数 f f f的输入,然后函数 f f f的输出是函数 g g g的输入,函数 g g g的输出以某种方式受到 x x x的影响,因此必须存在一个导数可以告诉我们这种影响。

我们模型的前向传递是一连串类似这些例子的函数。我们输入样本,数据流通过所有层和激活函数形成输出。在这里插入图片描述

在这里插入图片描述

如果你仔细观察,你会发现我们将损失描述为一个大函数,或者是多个输入的函数链——输入数据、权重和偏置。我们将输入数据传递到第一层,在那里我们也有该层的权重和偏置,然后输出通过ReLU激活函数流动,再通过另一层,带来更多的权重和偏置,再经过另一个ReLU激活,一直到最后——输出层和softmax激活。模型输出连同目标一起传递到损失函数,该函数返回模型的误差。我们不仅可以将损失函数视为一个函数,它接受模型的输出和目标作为参数来产生误差,而且还可以将其视为一个函数,如果我们将在前向传递期间执行的所有函数串联起来,就像我们刚才在图像中展示的那样,它接受目标、样本以及所有的权重和偏置作为输入。为了改善损失,我们需要了解每个权重和偏置是如何影响它的。如何对函数链进行这样的操作呢?通过使用链式法则。这条规则说明,一个函数链的导数是这个链中所有函数的导数的乘积,例如:

在这里插入图片描述

首先,我们写了外函数 f ( g ( x ) ) f(g(x)) f(g(x))关于内函数 g ( x ) g(x) g(x)的导数,因为这个内函数是它的参数。接下来,我们乘以内函数 g ( x ) g(x) g(x)关于其参数 x x x的导数。我们还用两种不同的记法表示了这个导数。对于有3个函数和多个输入的情况,这个函数关于 x x x的偏导数如下(在这种情况下我们不能使用撇号记法,因为我们必须提及我们相对于哪个变量进行求导):

在这里插入图片描述

为了计算一系列函数关于某个参数的偏导数,我们取链中外函数相对于内函数的偏导数,然后将这个偏导数乘以链中内函数相对于更内部函数的偏导数,然后再将其乘以更内部函数相对于链中其他函数的偏导数。我们一直重复到所讨论的参数。例如,请注意,中间的导数是关于 h ( x , z ) h(x, z) h(x,z)而不是 y y y的,因为 h ( x , z ) h(x, z) h(x,z)在参数 x x x的链中。链式法则被证明是找到单个输入对一系列函数输出的影响的最重要的规则,在我们的情况下,这是损失的计算。我们将在下一章讨论和编写反向传播时再次使用它。现在,让我们举一个链式法则的例子。

让我们求解 h ( x ) = 3 ( 2 x 2 ) 5 h(x) = 3(2x^2)^5 h(x)=3(2x2)5的导数。我们首先注意到的是我们有一个复杂的函数,可以分解成两个更简单的函数。第一个是方程中包含在括号内的部分,我们可以将其写为 g ( x ) = 2 x 2 g(x) = 2x^2 g(x)=2x2。这是我们指数化和与方程其余部分相乘的内部函数。然后方程的其余部分可以写成 f ( y ) = 3 y 5 f(y) = 3y^5 f(y)=3y5。在这种情况下, y y y是我们将其表示为 g ( x ) = 2 x 2 g(x)=2x^2 g(x)=2x2,当我们将它合并回去时,我们得到 h ( x ) = f ( g ( x ) ) = 3 ( 2 x 2 ) 5 h(x) = f(g(x)) = 3(2x^2)^5 h(x)=f(g(x))=3(2x2)5。要计算这个函数的导数,我们首先取外部的指数5,并将它放在我们要指数化乘以的组件前,后乘以前面的3,得到15。然后我们从5的指数中减去1,留下4。

在这里插入图片描述
然后链式法则告诉我们将外部函数的上述导数与内部函数的导数相乘,得到:

在这里插入图片描述

回想一下, 4 x 4x 4x 2 x 2 2x^2 2x2的导数,这是内部函数 g ( x ) g(x) g(x)。这在一个例子中突出了链式法则的概念,允许我们通过将导数链接在一起来计算更复杂函数的导数。请注意,我们乘以了内部函数的导数,但在外部函数的导数中保留了未改变的内部函数。

理论上,我们可以在这里就停下来,得到一个完全可用的函数导数。我们可以输入一些值到 15 ( 2 x 2 ) 4 ⋅ 4 x 15(2x^2)^4 \cdot 4x 15(2x2)44x中并得到答案。话虽如此,我们也可以继续前进并简化这个函数以便更多练习。回到原始问题,到目前为止我们已经找到了:

在这里插入图片描述

为了简化这个导数函数,我们首先取 ( 2 x 2 ) 4 (2x^2)^4 (2x2)4并分配4的指数:
在这里插入图片描述
合并 x:
在这里插入图片描述
常数如下:
在这里插入图片描述

我们稍后也会简化导数,以便加快计算速度——当我们可以提前解决问题时,没有理由重复相同的操作。

希望你现在能理解什么是导数和偏导数,什么是梯度,损失函数相对于权重和偏置的导数是什么意思,以及如何使用链式法则。目前,这些术语可能听起来没有联系,但我们将使用它们全部来执行反向传播步骤中的梯度下降,这将是下一章的主题。

本章的章节代码、更多资源和勘误表:https://nnfs.io/ch8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/18201.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

并查集 poj 2524,1611,1703,2236,2492,1988 练习集【蓝桥杯备赛】

目录 前言 并查集优势 Ubiquitous Religions poj 2524 问题描述 问题分析 代码 The Suspects poj 1611 问题描述 问题分析 代码 Wireless Network poj 2236 问题描述 问题分析 代码 分类 带权并查集合 权值树构建步骤 Find them, Catch them poj 1703 问题描述 问题分…

zabbix监控tomcat

1. 准备JDK环境 #vim /etc/profile export JAVA_HOME/usr/local/jdk export TOMCAT_HOME/usr/local/tomcat export PATH$PATH:$JAVA_HOME/bin:$JAVA_HOME/jre/bin:$TOMCAT_HOMOE/bin [rootCentOS8 ~]# source /etc/profile [rootCentOS8 ~]# java -version openjdk version &q…

Nuget For Unity插件介绍

NuGet for Unity:提升 Unity 开发效率的利器 NuGet 是 .NET 开发生态中不可或缺的包管理工具,你可以将其理解为Unity的Assets Store或者UPM,里面有很多库可以帮助我们提高开发效率。当你想使用一个库,恰好这个库没什么依赖(比如newtonjson),那么下载包并找到Dll直接…

如何在 Ubuntu 上安装 Mattermost 团队协作工具

简介 Mattermost 是一个开源、自托管的通信平台,专为团队协作设计。它类似于 Slack,提供聊天、消息传递和集成功能。Mattermost 在重视数据隐私的组织中特别受欢迎,因为它允许团队在自己的服务器上管理通信。以下是 Mattermost 的一些关键特…

初识Linux—— 基本指令(上)

前言 Linux简述 ​ Linux是一种开源、自由、类UNIX的操作系统,由著名的芬兰程序员林纳斯托瓦兹(Linus Torvalds)于1991年首次发布。Linux的内核在GNU通用公共许可证(GPL)下发布,这意味着任何人都可以自由…

VBA技术资料MF223:从文件添加新模块

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套,分为初级、中级、高级三大部分,教程是对VBA的系统讲解&#…

利用RAGflow和LM Studio建立食品法规问答系统

前言 食品企业在管理标准、法规,特别是食品原料、特殊食品法规时,难以通过速查法规得到准确的结果。随着AI技术的发展,互联网上出现很多AI知识库的解决方案。 经过一轮测试,找到问题抓手、打通业务底层逻辑、对齐行业颗粒度、沉…

路径规划——RRT*算法

路径规划——RRT*算法 算法原理 RRT Star 算法是一种渐近最优的路径规划算法,它是 RRT 算法的优化版本。RRT Star 算法通过不断地迭代和优化,最终可以得到一条从起点到目标点的最优路径。 在学习RRT Star 算法之前最好先学习一下RRT原始算法&#xff1…

Java——并发工具类库线程安全问题

摘要 本文探讨了Java并发工具类库中的线程安全问题,特别是ThreadLocal导致的用户信息错乱异常场景。文章通过一个Spring Boot Web应用程序示例,展示了在Tomcat线程池环境下,ThreadLocal如何因线程重用而导致异常,并讨论了其他并发…

网络编程套接字

前言: 认识了网络,我们就应该考虑一下如何编程实现不同主机上的应用进程之间如何进行双向互通的端点。 套接字(Socket)是网络编程的一种基本概念,套接字是应用程序通过网络协议进行通信的接口,是操作系统提…

计算机网络:运输层 —— TCP 的拥塞控制

文章目录 TCP的拥塞控制拥塞控制的基本方法流量控制与拥塞控制的区别拥塞控制分类闭环拥塞控制算法 TCP的四种拥塞控制方法(算法)窗口慢开始门限慢开始算法拥塞避免算法快重传算法快恢复算法 TCP拥塞控制的流程TCP拥塞控制与网际层拥塞控制的关系 TCP的拥…

vue学习第8章(vue的购物车案例)

🎉🎉🎉欢迎来到我的博客,我是一名自学了2年半前端的大一学生,熟悉的技术是JavaScript与Vue.目前正在往全栈方向前进, 如果我的博客给您带来了帮助欢迎您关注我,我将会持续不断的更新文章!!!🙏🙏🙏 文章目录…

【SpringBoot】什么是Maven,以及如何配置国内源实现自动获取jar包

前言 🌟🌟本期讲解关于Maven的了解和如何进行国内源的配置~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 &#x1f3…

【Linux】:进程信号(详谈信号捕捉 OS 运行)

✨ 来去都是自由风,该相逢的人总会相逢 🌏 📃个人主页:island1314 🔥个人专栏:Linux—登神长阶 ⛺️ 欢迎关注:👍点赞…

七、利用CSS和多媒体美化页面的习题

题目一&#xff1a; 利用CSS技术&#xff0c;结合表格和列表&#xff0c;制作并美化 “ 翡翠阁 ”页面。运行效果如下 运行效果&#xff1a; 代码 <!DOCTYPE html> <html><head><meta charset"utf-8" /><title>翡翠阁</title>&…

动态规划 —— 子数组系列-等差数列划分

1. 等差数列划分 题目链接&#xff1a; 413. 等差数列划分 - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/arithmetic-slices/description/ 2. 算法原理 状态表示&#xff1a;以某一个位置为结尾或者以某一个位置为起点 dp[i]表示&#xff1a;以i位置为…

vue使用List.reduce实现统计

需要对集合的某些元素的值进行计算时&#xff0c;可以在计算属性中使用forEach方法 1.语法&#xff1a;集合.reduce ( ( 定义阶段性累加后的结果 , 定义遍历的每一项 ) > 定义每一项求和逻辑执行后的返回结果 , 定义起始值 ) 2、简单使用场景&#xff1a;例如下面…

TensorFlow 2.0 windows11 GPU 训练环境配置

前言 在一切开始之前&#xff0c;请确保你的cmd命令行和powershell命令行可以正常打开。如果不能&#xff0c;建议重装系统。我不确定这是否会影响你最终的结果&#xff0c;毕竟windows的坑太多了。 安装顺序&#xff1a;visual studio -> cuda -> cudnn -> python…

MyISAM和InnoDB介绍及切换存储引擎方法

MyISAM 和 InnoDB 都是 MySQL 数据库管理系统中常用的存储引擎&#xff08;Storage Engine&#xff09;。存储引擎决定了数据库如何存储、读取、更新数据以及如何管理事务、锁定等操作。 1. MyISAM 存储引擎 MyISAM 是 MySQL 的默认存储引擎之一&#xff0c;尤其是在早期版本…

什么是嵌入式?

目录 一、什么是嵌入式 二、嵌入式系统的特点 &#xff08;一&#xff09;专用性与隐蔽性 &#xff08;二&#xff09;高可靠性与实时性 &#xff08;三&#xff09;资源固定与小型化 三、嵌入式系统的发展历史 &#xff08;一&#xff09;20 世纪 60 年代早期雏形 &am…