【AI知识点】反向传播(Backpropagation)

反向传播(Backpropagation) 是训练神经网络的核心算法,它通过反向逐层计算损失函数对每个权重的梯度,来反向逐层更新网络的权重,从而最小化损失函数。


一、反向传播的基本概念

1. 前向传播(Forward Propagation)

在前向传播中,输入数据从输入层通过隐藏层传递到输出层。网络通过层与层之间的连接(即权重)来计算每个节点的输出,最终生成网络的预测结果。

2. 计算损失(Compute Loss)

将网络的预测输出与真实值进行比较,计算损失函数(如均方误差),用来衡量网络的预测输出与真实值的差距。

3. 反向传播(Backward Propagation)

反向传播的过程主要由链式法则驱动。它通过逐层计算误差对权重的偏导数(梯度),从输出层反向传递到隐藏层,再传递到输入层(与前向传播顺序相反),以反向更新每层的权重,减少预测误差。

  • 前向传播相当于将输入数据从输入层逐步传递到输出层,得到预测结果。
  • 反向传播相当于从输出层开始反向传递误差,更新每一层的权重,使得网络在下次预测时能够减少误差。

4. 权重更新(Weights Update)

使用优化算法(如梯度下降)根据梯度更新权重。使得下一次前向传播时损失函数值减小。


二、反向传播的数学推导

对于一个简单的神经网络,损失函数 L L L 是关于网络输出 y y y 和真实值 t t t 的函数,而网络输出 y y y 又是关于输入 x x x 和权重 w w w 的函数。

通过链式法则,损失函数对权重的梯度可以表示为:

∂ L ∂ w = ∂ L ∂ y ⋅ ∂ y ∂ w \frac{\partial L}{\partial w} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w} wL=yLwy


三、反向传播的图示

在这里插入图片描述
图片来源:https://ai.stackexchange.com/questions/31566/different-ways-to-calculate-backpropagation-derivatives-any-difference

  • 前向传播(蓝色箭头)负责计算输出预测值(Out)和误差(Err)。
  • 反向传播(绿色和红色箭头)从输出误差(Err)开始,将误差逐层传播到隐藏层( a a a)和输入层(X),计算每个权重(W)的梯度,用于后续的权重更新。

四、反向传播的简单计算示例

假设我们有一个简单的两层神经网络:

在这里插入图片描述

  • 输入层(x):一个节点,输入值为 x x x
  • 隐藏层(a):一个节点,激活函数为 Sigmoid 函数。
  • 输出层(y):一个节点,激活函数为线性函数,输出值为 y y y

网络的权重:

  • 输入层到隐藏层的权重: w 1 w_1 w1
  • 隐藏层到输出层的权重: w 2 w_2 w2

给定以下初始条件:

  • 输入 x = 1 x = 1 x=1
  • 目标输出 t = 0 t = 0 t=0
  • 初始权重 w 1 = 0.5 w_1 = 0.5 w1=0.5 w 2 = 0.5 w_2 = 0.5 w2=0.5
  • 学习率 η = 0.1 \eta = 0.1 η=0.1

步骤1:前向传播

  1. 计算隐藏层的输入和输出

z = w 1 ⋅ x = 0.5 ⋅ 1 = 0.5 z = w_1 \cdot x = 0.5 \cdot 1 = 0.5 z=w1x=0.51=0.5

隐藏层的激活输出(使用 Sigmoid 函数):

a = σ ( z ) = 1 1 + e − z = 1 1 + e − 0.5 ≈ 0.6225 a = \sigma(z) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-0.5}} \approx 0.6225 a=σ(z)=1+ez1=1+e0.510.6225

  1. 计算输出层的输入和输出

y = w 2 ⋅ a = 0.5 ⋅ 0.6225 = 0.3112 y = w_2 \cdot a = 0.5 \cdot 0.6225 = 0.3112 y=w2a=0.50.6225=0.3112


步骤2:计算损失

使用均方误差(MSE)作为损失函数:

L = 1 2 ( y − t ) 2 = 1 2 ( 0.3112 − 0 ) 2 ≈ 0.0484 L = \frac{1}{2}(y - t)^2 = \frac{1}{2}(0.3112 - 0)^2 \approx 0.0484 L=21(yt)2=21(0.31120)20.0484


步骤3:反向传播

  1. 计算输出层对权重 w 2 w_2 w2 的梯度

∂ L ∂ w 2 = ∂ L ∂ y ⋅ ∂ y ∂ w 2 \frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_2} w2L=yLw2y

计算各部分:

  • 损失函数对输出 y y y 的导数:

∂ L ∂ y = y − t = 0.3112 − 0 = 0.3112 \frac{\partial L}{\partial y} = y - t = 0.3112 - 0 = 0.3112 yL=yt=0.31120=0.3112

  • 输出 y y y 对权重 w 2 w_2 w2 的导数:

∂ y ∂ w 2 = a = 0.6225 \frac{\partial y}{\partial w_2} = a = 0.6225 w2y=a=0.6225

  • 合并计算梯度:

∂ L ∂ w 2 = 0.3112 × 0.6225 ≈ 0.1938 \frac{\partial L}{\partial w_2} = 0.3112 \times 0.6225 \approx 0.1938 w2L=0.3112×0.62250.1938

  1. 计算隐藏层对权重 w 1 w_1 w1 的梯度

∂ L ∂ w 1 = ∂ L ∂ a ⋅ ∂ a ∂ z ⋅ ∂ z ∂ w 1 \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w_1} w1L=aLzaw1z

计算各部分:

  • 损失函数对隐藏层输出 a a a 的导数:

∂ L ∂ a = ∂ L ∂ y ⋅ ∂ y ∂ a = ( y − t ) ⋅ w 2 = 0.3112 ⋅ 0.5 = 0.1556 \frac{\partial L}{\partial a} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial a} = (y - t) \cdot w_2 = 0.3112 \cdot 0.5 = 0.1556 aL=yLay=(yt)w2=0.31120.5=0.1556

  • 隐藏层输出 a a a 对输入 z z z 的导数(Sigmoid 函数导数):

∂ a ∂ z = a ( 1 − a ) = 0.6225 ⋅ ( 1 − 0.6225 ) ≈ 0.2350 \frac{\partial a}{\partial z} = a (1 - a) = 0.6225 \cdot (1 - 0.6225) \approx 0.2350 za=a(1a)=0.6225(10.6225)0.2350

  • 输入 z z z 对权重 w 1 w_1 w1 的导数:

∂ z ∂ w 1 = x = 1 \frac{\partial z}{\partial w_1} = x = 1 w1z=x=1

  • 合并计算梯度:

∂ L ∂ w 1 = 0.1556 × 0.2350 × 1 ≈ 0.0365 \frac{\partial L}{\partial w_1} = 0.1556 \times 0.2350 \times 1 \approx 0.0365 w1L=0.1556×0.2350×10.0365


步骤4:更新权重

使用梯度下降法更新权重:

  1. 更新权重 w 2 w_2 w2

w 2 new = w 2 − η ⋅ ∂ L ∂ w 2 = 0.5 − 0.1 × 0.1938 ≈ 0.4806 w_2^{\text{new}} = w_2 - \eta \cdot \frac{\partial L}{\partial w_2} = 0.5 - 0.1 \times 0.1938 \approx 0.4806 w2new=w2ηw2L=0.50.1×0.19380.4806

  1. 更新权重 w 1 w_1 w1

w 1 new = w 1 − η ⋅ ∂ L ∂ w 1 = 0.5 − 0.1 × 0.0365 ≈ 0.4963 w_1^{\text{new}} = w_1 - \eta \cdot \frac{\partial L}{\partial w_1} = 0.5 - 0.1 \times 0.0365 \approx 0.4963 w1new=w1ηw1L=0.50.1×0.03650.4963


步骤5:验证更新后的网络

再次进行前向传播,计算新的输出和损失。

  1. 新的隐藏层输入和输出

z ′ = w 1 new ⋅ x = 0.4963 ⋅ 1 = 0.4963 z' = w_1^{\text{new}} \cdot x = 0.4963 \cdot 1 = 0.4963 z=w1newx=0.49631=0.4963

a ′ = σ ( z ′ ) = 1 1 + e − 0.4963 ≈ 0.6216 a' = \sigma(z') = \frac{1}{1 + e^{-0.4963}} \approx 0.6216 a=σ(z)=1+e0.496310.6216

  1. 新的输出层输出

y ′ = w 2 new ⋅ a ′ = 0.4806 ⋅ 0.6216 ≈ 0.2988 y' = w_2^{\text{new}} \cdot a' = 0.4806 \cdot 0.6216 \approx 0.2988 y=w2newa=0.48060.62160.2988

  1. 新的损失

L ′ = 1 2 ( y ′ − t ) 2 = 1 2 ( 0.2988 − 0 ) 2 ≈ 0.0447 L' = \frac{1}{2}(y' - t)^2 = \frac{1}{2}(0.2988 - 0)^2 \approx 0.0447 L=21(yt)2=21(0.29880)20.0447


结果分析

更新权重后,损失从 0.0484 减少到 0.0447,说明网络朝着最小化损失的方向更新,模型性能有所提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1556993.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

文件丢失一键找回,四大数据恢复免费版工具推荐!

丢失数据的情况虽然不经常出现,但一旦出现都会让人头疼不已,而这时候,要如何恢复丢失的数据呢?一款免费好用的数据恢复工具就派上用场了!接下来就为大家推荐几款好用的数据恢复工具! 福昕数据恢复 直达链…

Redis list 类型

list类型 类型介绍 列表类型 list 相当于 数组或者顺序表 list内部的编码方式更接近于 双端队列 ,支持头插 头删 尾插 尾删。 需要注意的是,Redis的下标支持负数下标。 比如数组大小为5,那么要访问下标为 -2 的值可以理解为访问 5 - 2 3 …

【韩顺平Java笔记】第8章:面向对象编程(中级部分)【272-284】

272. 包基本介绍 272.1 看一个应用场景 272.2 包的三大作用 272.3 包的基本语法 273. 包原理 274. 包快速入门 在不同的包下面创建不同的Dog类 275. 包命名 276. 常用的包 一个包下,包含很多的类,java 中常用的包有: java.lang.* //lang 包是基本包,默认引入&…

农业政策与市场分析:解读当前政策导向下的农业发展趋势

在快速变化的全球经济格局中,农业作为国家稳定发展的基石,其政策走向与市场动态备受瞩目。本文将深入剖析当前的农业政策背景,探讨其对设计的导向作用,以及市场趋势的反馈与影响,为农业可持续发展提供洞见。 1. 政策背…

【大模型理论篇】大模型相关的周边技术分享-关于《NN and DL》的笔记

本文所要介绍的一本书《Neural Networks and Deep Learning》,该书作者Michael Nielsen,Y Combinator Research的研究员,是多年之前自己看的一本基础书籍,很适合入门了解一些关于深度学习的概念知识,当然也包含了一些小…

MyBatis 批量插入方案

MyBatis 批量插入 MyBatis 插入数据的方法有几种: for 循环,每次都重新连接一次数据库,每次只插入一条数据。 在编写 sql 时用 for each 标签,建立一次数据库连接。 使用 MyBatis 的 batchInsert 方法。 下面是方法 1 和 2 的…

三相逆变器中LCL滤波器分析

1.LCL滤波器 传统三相逆变器使用的是L型滤波器,其设计简单,但也存在着一些问题,如在同样的滤波效果下,L型滤波器电感尺寸、重量较大,成本较高,并且随着电感值的增大,其上的电压降增加比较明显&…

【MySQL必知会】事务

目录 🌈前言🌈 📁 事务概念 📁 事务操作 📁 事务提交方式 📁 隔离级别 📁 MVCC 📂 3个隐藏列字段 📂 undo日志 📂 Read View视图 📁 RR和R…

【GESP】C++一级练习BCQM3028,输入-计算-浮点型格式化输出

目前的几道题主要围绕浮点型的计算和格式化输出。强化基础语法练习。 详解详见:https://www.coderli.com/gesp-1-bcqm3028/ 【GESP】C一级练习BCQM3028,输入-计算-浮点型格式化输出 | OneCoder目前的几道题主要围绕浮点型的计算和格式化输出。强化基础语…

说说BPMN概念及应用

BPMN(Business Process Modeling and Notation)即业务流程建模与标注,是一种由OMG(Object Management Group,对象管理组织)制定的业务流程建模语言。以下是对BPMN标准的详细解释: 一、BPMN的起…

短剧系统源码短剧平台开发(H5+抖小+微小)部署介绍流程

有想法加入国内短剧赛道的请停下脚步,耐心看完此篇文章,相信一定会对您有所帮助的,下面将排序划分每一个步骤,短剧源码、申请资料、服务器选择、部署上架到正常运行等几个方面,整理了一些资料,来为大家举例…

Spring Boot助力医院数据管理

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常适…

MySQL进阶学习一(2024.10.07版)

2024-10-06 -------------------------------------------------------------------------------------------------------------------------------- 1.一条SQL语句是如何执行的 单进程的多线程模型 MySQL的物理目录 show global variables like "%basedir%"; …

LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为AirPassengers 目录 数据集简介 1.步骤一 2.步骤二 3.步骤三 4.步骤四 数据集简介 AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box,…

面向对象特性中 继承详解

目录 概念: 定义: 定义格式 继承关系和访问限定符 基类和派生类对象赋值转换: 继承中的作用域: 派生类的默认成员函数 继承与友元: 继承与静态成员: 复杂的菱形继承及菱形虚拟继承: 虚…

VGG16模型实现MNIST图像分类

MNIST图像数据集 MNIST(Modified National Institute of Standards and Technology)是一个经典的机器学习数据集,常用于训练和测试图像处理和机器学习算法,特别是在数字识别领域。该数据集包含了大约 7 万张手写数字图片&#xf…

喜讯 | 攸信技术入选第六批专精特新“小巨人”企业

日前,根据工信部评审结果,厦门市工业和信息化局公示了第六批专精特新“小巨人”企业和第三批专精特新“小巨人”复核通过企业名单,其中,厦门攸信信息技术有限公司进入第六批专精特新“小巨人”企业培育。 “专精特新”企业是指具有…

图像分割恢复方法

传统的图像分割方法主要依赖于图像的灰度值、纹理、颜色等特征,通过不同的算法将图像分割成多个区域。这些方法通常可以分为以下几类: 1.基于阈值的方法 2.基于边缘的方法 3.基于区域的方法 4.基于聚类的方法 下面详细介绍这些方法及其示例代码。 1. 基…

代码随想录--栈与队列--用栈实现队列

队列是先进先出,栈是先进后出。 如图所示: 题目 使用栈实现队列的下列操作: push(x) – 将一个元素放入队列的尾部。 pop() – 从队列首部移除元素。 peek() – 返回队列首部的元素。 empty() – 返回队列是否为空。 示例: MyQueue qu…

draw.io 设置默认字体及添加常用字体

需求描述 draw.io 是一个比较好的开源免费画图软件。但是其添加容器或者文本框时默认的字体是 Helvetica,一般的期刊、会议论文或者学位论文要求的英文字体是 Times New Roman,中文字体是 宋体,所以一般需要在文本字体选项里的下拉列表选择 …