反向传播(Back Propagation,简称BP)

反向传播算法是用于训练神经网络的核心算法之一,它通过计算损失函数(如均方误差或交叉熵)相对于每个权重参数的梯度,来优化神经网络的权重。

1.前向传播(Forward Propagation)

步骤

  • 输入层:接收输入数据。

  • 隐藏层:对输入数据进行加权求和,并通过激活函数得到输出。

  • 输出层:将隐藏层的输出再次加权求和,并通过激活函数得到最终输出。

代码

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义一个conv1的卷积层 1 input image channel, 6 output channels, 3x3 square convolution kernelself.conv1 = nn.Conv2d(1, 6, 3)def forward(self, x):#在前向传播过程中使用conv1卷积层x=self.conv1(x)return xnet=Net()input_tensor=torch.randn(1,1,32,32)
output_tensor=net(input_tensor)
print(output_tensor.size())
#torch.Size([1, 6, 30, 30])

前向传播主要用于神经网络的预测阶段

作用

  1. 计算神经网络的输出结果,用于预测或计算损失。

  2. 在反向传播中使用,通过计算损失函数相对于每个参数的梯度来优化网络。

2.BP基础之梯度下降算法

梯度下降算法的目标是找到使损失函数 L(\theta)最小的参数 \theta,其核心是沿着损失函数梯度的负方向更新参数,以逐步逼近局部或全局最优解,从而使模型更好地拟合训练数据。

公式

w_{ij}^{new}= w_{ij}^{old} - \alpha \frac{\partial E}{\partial w_{ij}}

其中,\alpha是学习率:

学习率是梯度下降算法中一个重要的参数,决定目标函数能否收敛到局部最小值,以及何时收敛到最小值

学习率太小,每次训练之后的效果太小,增加时间和算力成本。
学习率太大,大概率会跳过最优解,进入无限的训练和震荡中。
解决的方法就是,学习率也需要随着训练的进行而变化。

过程

第一步 初始化参数

第二步 计算梯度

第三步 更新参数

第四步 迭代更新

传统下降方式

批量梯度下降Batch Gradient Descent ,BGD

  • 特点

    • 每次更新参数时,使用整个训练集来计算梯度。

  • 优点

    • 收敛稳定,能准确地沿着损失函数的真实梯度方向下降。

    • 适用于小型数据集。

  • 缺点

    • 对于大型数据集,计算量巨大,更新速度慢。

    • 需要大量内存来存储整个数据集。

  • 公式

\theta := \theta - \alpha \frac{1}{m} \sum_{i=1}^{m} \nabla_\theta L(\theta; x^{(i)}, y^{(i)})

其中,m是训练集样本总数,x^{(i)}, y^{(i)}是第 i 个样本及其标签

随机梯度下降Stochastic Gradient Descent, SGD

可降低每次迭代时的计算代价

  • 特点

    • 每次更新参数时,仅使用一个样本来计算梯度。

  • 优点

    • 更新频率高,计算快,适合大规模数据集。

    • 能够跳出局部最小值,有助于找到全局最优解。

  • 缺点

    • 收敛不稳定,容易震荡,因为每个样本的梯度可能都不完全代表整体方向。

    • 需要较小的学习率来缓解震荡。

  • 公式

\theta := \theta - \alpha \nabla_\theta L(\theta; x^{(i)}, y^{(i)})

其中,x^{(i)}, y^{(i)} 是当前随机抽取的样本及其标签。

小批量梯度下降Mini-batch Gradient Descent ,MGBD

在每轮迭代中,随机均匀采样多个样本组成一个小批量,然后使用这个小批量来计算梯度。

  • 特点

    • 每次更新参数时,使用一小部分训练集(小批量)来计算梯度。

  • 优点

    • 在计算效率和收敛稳定性之间取得平衡。

    • 能够利用向量化加速计算,适合现代硬件(如GPU)。

  • 缺点

    • 选择适当的批量大小比较困难;批量太小则接近SGD,批量太大则接近批量梯度下降。

    • 通常会根据硬件算力设置为32\64\128\256等2的次方。

  • 公式

\theta := \theta - \alpha \frac{1}{b} \sum_{i=1}^{b} \nabla_\theta L(\theta; x^{(i)}, y^{(i)})

其中,b是小批量的样本数量,也就是 batch\_size

影响小梯度批量下降法的主要因素:

梯度估计 

批量大小

v_t = \beta v_{t-1} + (1 - \beta) x_t\alpha

优化下降方式

指数加权平均

Exponential Moving Average,简称EMA,是一种平滑时间序列数据的技术,它通过对过去的值赋予不同的权重来计算平均值。

t=1 : v_0 = x_0

t>1 :v_t = \beta v_{t-1} + (1 - \beta) x_t

import numpy as np
import matplotlib.pyplot as pltdef test():#生成模拟的股市数据np.random.seed(0)days=30#模拟股价的走势randprice=np.random.randn(days)stock_prices=np.cumsum(randprice*2+0.5)+100#cumsum:用于计算张量中元素的累积和(cumulative sum)的函数#计算SMA-简单移动平均window_size=5print(np.ones(window_size) / window_size)simple_moving_avg = np.convolve(stock_prices, np.ones(window_size) / window_size, mode="valid")print(stock_prices)print(simple_moving_avg)"""
np.ones(window_size) / window_size:
np.ones(window_size)生成一个长度为window_size的一维数组,其中所有元素都是1。
/ window_size则将这个数组中的每个元素都除以window_size,这样就得到了一个长度为window_size的数组,其中每个元素都是1/window_size。这个数组可以看作是一个简单的权重向量,表示每个数据点在计算移动平均时所占的比例。np.convolve(stock_prices, np.ones(window_size) / window_size, mode="valid"):
np.convolve函数执行卷积操作。在这个上下文中,它实际上是在计算滑动窗口内的平均值。
第一个参数stock_prices是一个包含了股票价格的时间序列。
第二个参数就是之前创建的权重向量np.ones(window_size) / window_size。
mode="valid"指定只计算那些完全重叠的部分。这意味着对于长度为n的价格序列和长度为m的权重向量,输出序列的长度将是n-m+1。由于我们关心的是完整的窗口,所以使用“valid”模式来确保每个平均值都是基于完整窗口计算得出的。最终的结果simple_moving_avg就是一个数组,其中包含了通过上述方式计算出的移动平均值。每个值都是连续window_size个股票价格的平均值,从而平滑了原始价格序列并帮助识别趋势。"""#计算EMAbeta=0.9ema=np.zeros_like(stock_prices)ema[0]=stock_prices[0]for i in range(1,days):ema[i]=beta*ema[i-1]+(1-beta)*stock_prices[i]# 定义窗口大小plt.figure(figsize=(8,6))# 绘制股市数据、简单移动平均值和指数加权平均值的走势plt.plot(stock_prices,label='stock price',color='b',marker='o')    plt.plot(range(window_size - 1, days),simple_moving_avg,label=f"SMA (N = {window_size})",color="orange",marker="x",)"""
range(window_size - 1, days):这里创建了一个范围列表,从window_size - 1开始,一直到days结束(不包括days)。这是因为简单移动平均线需要至少window_size个数据点才能开始计算,所以第一个点是在第window_size天,因此索引是从window_size - 1开始的。days可能代表总的天数或者数据点的数量。simple_moving_avg:这是包含简单移动平均值的数组或序列,与上面的range对应,用于标记图表上的Y轴值。label=f"SMA (N = {window_size})":这个参数设置了图例的标签文本。在这里,SMA代表简单移动平均,N代表窗口大小window_size,这通常是指用来计算移动平均的数据点数量。marker="x":设置标记样式为"x"。当数据点在图表上被标记出来时,它们将以"x"的形式出现。
整个代码块的作用是画出一条基于给定窗口大小计算得到的简单移动平均线"""plt.plot(ema,label=f'EMA=(beta={beta})',color='r',marker='x')plt.title("Stock Price Trends")plt.xlabel("Days")plt.ylabel("Price")plt.legend()#显示图例plt.show()#显示图表if __name__=="__main__":test()

Momentum

动量(Momentum)是对梯度下降的优化方法,可以更好地应对梯度变化和梯度消失问题,从而提高训练模型的效率和稳定性。

参数更新时在一定程度上保留之前更新的方向,同时又利用当前batch的梯度微调最终的更新方向,简言之就是通过积累之前的动量来加速当前的梯度。

Momentum 算法是对梯度值的平滑调整,但是并没有对梯度下降中的学习率进行优化。

  • 惯性效应: 该方法加入前面梯度的累积,这种惯性使得算法沿着当前的方向继续更新。如遇到鞍点,也不会因梯度逼近零而停滞。

  • 减少震荡: 该方法平滑了梯度更新,减少在鞍点附近的震荡,帮助优化过程稳定向前推进。

  • 加速收敛: 该方法在优化过程中持续沿着某个方向前进,能够更快地穿越鞍点区域,避免在鞍点附近长时间停留。

梯度更新算法包括两个步骤 :

更新动量项:利用当前梯度和历史动量来计算新的动量项。beta通常=0.9

v_{t} = \beta v_{t-1} + (1 - \beta) \nabla_\theta J(\theta_t)

更新参数:利用更新后的动量项来调整权重参数。

v_{t}=\beta v_{t-1}+(1-\beta)\nabla_\theta J(\theta_t) \\ \theta_{t}=\theta_{t-1}-\eta v_{t}

梯度计算:在每个时间步计算当前的梯度,用于更新动量项和权重参数。

AdaGrad

在标准的梯度下降中,每个参数每次迭代时都使用相同的学习率。由于每个参数的维度上收敛速度都不同,因此根据不同参数的收敛情况分别设置学习率。

AdaGrad(Adaptive Gradient)算法每次迭代时根据每个参数的梯度值自适应地调整每个参数的学习率,从而避免全局学习率难以适应所有维度的问题。

AdaGrad算法中,若某个参数的偏导数累乘积比较大,其学习率相对较小;相反,如果其偏导数累乘积较小,其学习率相对较大。但整体随着迭代次数的增加,学习率逐渐缩小。

优点:可以防止梯度过大导致的震荡、适合稀疏数据

缺点:学习率过度衰减、你适合非稀疏数据

AdaGrad过程:

1.初始化

初始化参数\theta和学习率\eta

2.梯度计算

g_t = \nabla_\theta J(\theta_t)

3.累积梯度的平方

G_{t,i} = G_{t-1,i} + g_{t,i}^2

4.参数更新

\theta_{t,i} = \theta_{t-1,i} - \frac{\eta}{\sqrt{G_{t,i} + \epsilon}} g_{t,i}

\eta是全局的初始学习率

\epsilon非常小,通常取10^{-8}

RMSProp

可以在有些情况下避免AdaGrad算法中学习率不断单调下降从而过早衰减的缺点。

RMSprop算法计算每次迭代梯度平方的指数衰减移动平均

RMSprop算法和AdaGrad算法的区别在于Gt的计算由累积方式变成了指数衰减移动平均。在迭代过程中,每个参数的学习率并不是呈衰减趋势,既可以变小也可以变大。

过程:

1.初始化

2.梯度计算

3.更新梯度平方的指数加权平均

4.参数更新

优点

  • 适应性强:RMSProp自适应调整每个参数的学习率,对于梯度变化较大的情况非常有效,使得优化过程更加平稳。

  • 适合非稀疏数据:相比于AdaGrad,RMSProp更加适合处理非稀疏数据,因为它不会让学习率减小到几乎为零。

  • 解决过度衰减问题:通过引入指数加权平均,RMSProp避免了AdaGrad中学习率过快衰减的问题,保持了学习率的稳定性

缺点

依赖于超参数的选择

Adam

Adam算法可以看做动量法和RMSprop算法的结合,不但使用动量作为参数更新方向,而且可以自适应调整学习率。

Adam(Adaptive Moment Estimation)算法将动量法和RMSProp的优点结合在一起:

  • 动量法:通过一阶动量(即梯度的指数加权平均)来加速收敛,尤其是在有噪声或梯度稀疏的情况下。

  • RMSProp:通过二阶动量(即梯度平方的指数加权平均)来调整学习率,使得每个参数的学习率适应其梯度的变化。

Adam过程

1.初始化

2.梯度计算

3.一阶动量估计-梯度的指数加权平均

4.二阶动量估计-梯度平方的指数加权平均

5.偏差校正

6.参数更新

优点:高效稳健 避免全局学习率设定不合适的问题  适用各种深度学习模型

缺点:超参数敏感 过拟合风险

3.BP基础之链式法则

链式求导法则(Chain Rule)是微积分中的一个重要法则,用于求复合函数的导数。在深度学习中,链式法则是反向传播算法的基础,这样就可以通过分层的计算求得损失函数相对于每个参数的梯度。

import torch
import torch.nn as nndef test():x=torch.tensor(1.0)w=torch.tensor(0.5,requires_grad=True)b=torch.tensor(0.5,requires_grad=True)y=(torch.exp(-(w*x+b))+1)**-1y.backward()print(w.grad)#tensor(0.1966)if __name__=="__main__":test()

4.反向传播

反向传播(BP)通过计算损失函数相对于每个参数的梯度来调整权重,使模型在训练数据上的表现逐渐优化。反向传播结合了链式求导法则和梯度下降算法,是神经网络模型训练过程中更新参数的关键步骤。

为了训练神经网络,首先要将权重随机初始化一个接近0的,范围在[-ε,ε]之间的数,然后进行反向传播,再进行梯度检验,最后使用梯度下降,或者其它高级优化算法 来最小化代价函数J。

步骤

  1. 前向传播:得到预测值。

  2. 计算损失:通过损失函数$$ L(y{\text{pred}}, y{\text{true}}) $$ 计算预测值和真实值的差距。

  3. 梯度计算:反向传播的核心是计算损失函数 $$L$$ 对每个权重和偏置的梯度。

  4. 更新参数:一旦得到每层梯度,就可以使用梯度下降算法来更新每层的权重和偏置,使得损失逐渐减小。

  5. 迭代训练:将前向传播、梯度计算、参数更新的步骤重复多次,直到损失函数收敛或达到预定的停止条件。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(2, 2)  # 输入2个特征,输出2个特征self.linear2 = nn.Linear(2, 1)  # 输入2个特征,输出1个特征# 正确的权重和偏置初始化self.linear1.weight.data = torch.tensor([[1.0, 2.0], [0.0, 1.0]])self.linear1.bias.data = torch.tensor([0.0, 0.0])self.linear2.weight.data = torch.tensor([[3.0]])self.linear2.bias.data = torch.tensor([0.0])def forward(self, x):x = self.linear1(x)x = torch.sigmoid(x)x = self.linear2(x)x = torch.sigmoid(x)return xif __name__ == '__main__':# 输入张量形状为(1, 2)inputs = torch.tensor([[0.1, 0.8]])# 目标张量形状为(1, 1)target = torch.tensor([[0.5]])net = Net()output = net(inputs)# 计算损失loss = torch.sum((output - target)**2) / 2# 定义优化器optimizer = optim.SGD(net.parameters(), lr=0.01)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 打印梯度print("Linear1 weight gradient before update:")print(net.linear1.weight.grad.data)print("Linear2 weight gradient before update:")print(net.linear2.weight.grad.data)# 更新参数optimizer.step()# 打印更新后的网络参数print("Updated network parameters:")print(net.state_dict())

没找到这块代码哪里出错了 但是运行不出来。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1537717.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

孙怡带你深度学习(2)--PyTorch框架认识

文章目录 PyTorch框架认识1. Tensor张量定义与特性创建方式 2. 下载数据集下载测试展现下载内容 3. 创建DataLoader(数据加载器)4. 选择处理器5. 神经网络模型构建模型 6. 训练数据训练集数据测试集数据 7. 提高模型学习率 总结 PyTorch框架认识 PyTorc…

计算机人工智能前沿进展-大语言模型方向-2024-09-13

计算机人工智能前沿进展-大语言模型方向-2024-09-13 1. OneEdit: A Neural-Symbolic Collaboratively Knowledge Editing System Authors: Ningyu Zhang, Zekun Xi, Yujie Luo, Peng Wang, Bozhong Tian, Yunzhi Yao, Jintian Zhang, Shumin Deng, Mengshu Sun, Lei Liang, Z…

基于51单片机的直流数字电流表proteus仿真

地址: https://pan.baidu.com/s/1adZbhgOBvvg0KsCO6_ZiAw 提取码:1234 仿真图: 芯片/模块的特点: AT89C52/AT89C51简介: AT89C52/AT89C51是一款经典的8位单片机,是意法半导体(STMicroelectro…

高级I/O知识分享【5种IO模型 || select || poll】

博客主页:花果山~程序猿-CSDN博客 文章分栏:Linux_花果山~程序猿的博客-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 一,前文 2&a…

在Unity UI中实现UILineRenderer组件绘制线条

背景介绍 在Unity的UI系统中,绘制线条并不像在3D世界中那样直观(使用Unity自带的LineRender组件在UI中连线并不方便,它在三维中更合适)。没有内置的工具来处理这种需求。如果你希望在UI元素之间绘制连接线(例如在UI上连接不同的图标或控件)&a…

20240918 每日AI必读资讯

o1突发内幕曝光?谷歌8月论文已揭示原理,大模型光有软件不存在护城河 - 谷歌DeepMind一篇发表在8月的论文,揭示原理和o1的工作方式几乎一致 - 谷歌DeepMind这篇论文的题目是:优化LLM测试时计算比扩大模型参数规模更高效。 - Op…

828华为云征文 | 云服务器Flexus X实例:one-api 部署,支持众多大模型

目录 一、one-api 介绍 二、部署 one-api 2.1 拉取镜像 2.2 部署 one-api 三、运行 one-api 3.1 添加规则 3.2 运行 one-api 四、添加大模型 API 4.1 添加大模型 API 五、总结 本文通过 Flexus云服务器X实例 部署 one-api。Flexus云服务器X实例是新一代面向中小企业…

拥控算法BBR入门1

拥塞控制算法只与本地有关 一个TCP会话使用的拥塞控制算法只与本地有关。 两个TCP系统可以在TCP会话的两端使用不同的拥塞控制算法 Bottleneck Bandwidth and Round-trip time Bottleneck 瓶颈 BBR models the network to send as fast as the available bandwidth and is 2…

Java | Leetcode Java题解之第414题第三大的数

题目: 题解: class Solution {public int thirdMax(int[] nums) {Integer a null, b null, c null;for (int num : nums) {if (a null || num > a) {c b;b a;a num;} else if (a > num && (b null || num > b)) {c b;b num;…

驱动器磁盘未格式化危机:专业数据恢复实战指南

认识危机:驱动器中的磁盘未被格式化 在日常的数字生活中,我们时常依赖于各种存储设备来保存重要的文件、照片、视频等数据。然而,当某一天你尝试访问某个驱动器或外接硬盘时,突然弹出的“驱动器中的磁盘未被格式化。您想现在格式…

floodfill+DFS(2)

文章目录 太平洋大西洋流水问题扫雷游戏迷路的机器人 太平洋大西洋流水问题 class Solution { public:vector<vector<int>> res;int m 0, n 0;vector<vector<int>> pacificAtlantic(vector<vector<int>>& heights) {m heights.size…

iOS 18 正式上線,但 Apple Intelligence 還要再等一下

在 iPhone 16 即將正式開賣之際&#xff0c;Apple 如約上線了 iOS 18。雖然今年的重頭戲 Apple Intelligence 還要等下月的 iOS 18.1 才會有&#xff0c;但自訂主畫面和全新的鎖定頁面、控制中心等特性已可在最新的版本中體驗。除此之外&#xff0c;相簿、訊息、地圖、Safari 等…

React学习day07-ReactRouter-抽象路由模块、路由导航、路由导航传参、嵌套路由、默认二级路由的设置、两种路由模式

14、ReactRouter续 &#xff08;2&#xff09;抽象路由模块 1&#xff09;新建page文件夹&#xff0c;存放组件 组件内容&#xff1a; 2&#xff09;新建router文件夹&#xff0c;在其下创建实例 3&#xff09;实例导入&#xff0c;使用 4&#xff09;效果 &#xff08;3&…

佛山网站制作与设计

佛山网站制作与设计 在当今数字化时代&#xff0c;网站已成为企业展示形象、推广产品和服务的重要窗口。佛山作为一个经济迅速发展的城市&#xff0c;其网站制作与设计也日益受到重视。优质的网站不仅能提升企业的品牌形象&#xff0c;更是实现商业价值的重要工具。 一、网站制…

cout无法正常显示中文

cout无法正常显示中文 虽然你使用了buf.length()来指定写入的字节数&#xff0c;但是在包含中文字符&#xff08;UTF-8编码下每个中文字符占用3个字节&#xff09;的情况下&#xff0c;直接使用length()可能不会正确反映实际的字节数&#xff0c;因为它给出的是字符数而非字节…

RK3568平台(文件系统篇)VFS虚拟文件系统

一.VFS虚拟文件系统简介 为什么 Linux 内核的文件系统类型那么多,都能挂载上呢?为什么系统里可以直接 mount 其他文件系统呢?为什么 Linux 的虚拟文件系统这么强大?这得益于它的数据结构设计得十分精妙。 为支持各种本机文件系统,且在同时允许访问其他操作系统的文件,L…

gitee远程仓库OPEN GIT BASH HERE从错误中学习

推荐一个ai软件&#xff08;搜索器搜索kimi&#xff09;&#xff0c;是一个ai&#xff0c;有什么错误跟着一步步解决就可以了 当你创建一个仓库 会出现这些 打开这个窗口跟着敲就行了 到这里为止我还没出现错误&#xff0c;后面我把remote add添加远程仓库的地址输错地址了 所…

C++第七节课 运算符重载

一、运算符重载 并不是所有情况下都需要运算符重载&#xff0c;要看这个运算符对这个类是否有意义&#xff01; 例如&#xff1a;日期减日期可以求得两个日期之间的天数&#xff1b;但是日期 日期没有意义&#xff01; #include<iostream> using namespace std; clas…

文档内容识别系统源码分享

文档内容识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

鸿蒙开发之ArkUI 界面篇 五

Image 图片组件&#xff0c;用专门用于显示图片 语法&#xff1a;Image(图片源)&#xff0c;这里可以是网络、也可以是本地的图片 例如&#xff1a;Image(https://wxls-cms.oss-cn-hangzhou.aliyuncs.com/online/2024-04-18/218da022-f4bf-456a-99af-5cb8e157f7b8.jpg)效果如下…