采用自适应调整参数的 BP 网络学习改进算法详解

采用自适应调整参数的 BP 网络学习改进算法详解

一、引言

BP(Back Propagation)神经网络是一种广泛应用于解决复杂非线性问题的有效工具,如模式识别、函数逼近、数据分类等领域。然而,传统的 BP 算法存在一些局限性,其中学习率和其他参数的固定设置往往导致训练效率低下、收敛速度慢甚至可能陷入局部最优解。为了克服这些问题,采用自适应调整参数的 BP 网络学习改进算法应运而生。本文将详细阐述这种改进算法的原理、步骤以及通过丰富的代码示例展示其实现过程。

二、传统 BP 网络学习算法概述

(一)网络结构与前向传播

BP 神经网络通常由输入层、若干隐藏层和输出层组成。假设输入层有 n n n 个神经元,输入向量为 x = ( x 1 , x 2 , ⋯ , x n ) \mathbf{x}=(x_1,x_2,\cdots,x_n) x=(x1,x2,,xn);中间有 l l l 个隐藏层,第 i i i 个隐藏层有 h i h_i hi 个神经元;输出层有 m m m 个神经元,输出向量为 y = ( y 1 , y 2 , ⋯ , y m ) \mathbf{y}=(y_1,y_2,\cdots,y_m) y=(y1,y2,,ym)

在前向传播过程中,对于任意一层(以第 j j j 层为例),神经元的输入为:

n e t k j = ∑ i w i k j − 1 o i j − 1 + b k j net_{k}^{j}=\sum_{i}w_{ik}^{j - 1}o_{i}^{j - 1}+b_{k}^{j} netkj=iwikj1oij1+bkj

其中, w i k j − 1 w_{ik}^{j - 1} wikj1 是第 j − 1 j - 1 j1 层神经元 i i i 到第 j j j 层神经元 k k k 的连接权重, o i j − 1 o_{i}^{j - 1} oij1 是第 j − 1 j - 1 j1 层神经元 i i i 的输出, b k j b_{k}^{j} bkj 是第 j j j 层神经元 k k k 的偏置。神经元的输出经过激活函数 f ( ⋅ ) f(\cdot) f() 处理,如常见的 Sigmoid 函数 f ( x ) = 1 1 + e − x f(x)=\frac{1}{1 + e^{-x}} f(x)=1+ex1,即 o k j = f ( n e t k j ) o_{k}^{j}=f(net_{k}^{j}) okj=f(netkj)

(二)误差计算与反向传播

对于训练样本集 { ( x ( p ) , t ( p ) ) } p = 1 P \{(\mathbf{x}^{(p)},\mathbf{t}^{(p)})\}_{p = 1}^{P} {(x(p),t(p))}p=1P,其中 x ( p ) \mathbf{x}^{(p)} x(p) 是第 p p p 个输入样本, t ( p ) \mathbf{t}^{(p)} t(p) 是对应的目标输出。误差通常采用均方误差(MSE)衡量:

E = 1 2 P ∑ p = 1 P ∑ k = 1 m ( y k ( p ) − t k ( p ) ) 2 E=\frac{1}{2P}\sum_{p = 1}^{P}\sum_{k = 1}^{m}(y_{k}^{(p)}-t_{k}^{(p)})^{2} E=2P1p=1Pk=1m(yk(p)tk(p))2

在反向传播过程中,根据误差函数对权重的梯度来更新权重。以输出层到最后一个隐藏层的权重更新为例,通过链式法则计算权重 w k m l w_{km}^{l} wkml 的梯度:

∂ E ∂ w k m l = ∑ p = 1 P ∂ E ∂ y m ( p ) ∂ y m ( p ) ∂ n e t m l ∂ n e t m l ∂ w k m l \frac{\partial E}{\partial w_{km}^{l}}=\sum_{p = 1}^{P}\frac{\partial E}{\partial y_{m}^{(p)}}\frac{\partial y_{m}^{(p)}}{\partial net_{m}^{l}}\frac{\partial net_{m}^{l}}{\partial w_{km}^{l}} wkmlE=p=1Pym(p)Enetmlym(p)wkmlnetml

然后根据梯度下降法更新权重:

w k m l ( t + 1 ) = w k m l ( t ) − η ∂ E ∂ w k m l w_{km}^{l}(t + 1)=w_{km}^{l}(t)-\eta\frac{\partial E}{\partial w_{km}^{l}} wkml(t+1)=wkml(t)ηwkmlE

其中, η \eta η 是学习率, t t t 表示训练次数。类似地,可以计算其他层之间的权重更新公式。

三、传统 BP 算法的问题分析

(一)固定学习率问题

  1. 收敛速度不稳定
    在传统 BP 算法中,学习率是一个固定值。如果学习率设置过大,可能会导致权重更新幅度过大,使得网络在训练过程中跳过误差曲面的最小值点,甚至可能导致训练过程不收敛。相反,如果学习率设置过小,权重更新缓慢,会导致训练时间过长,尤其在处理复杂的高维数据时问题更加突出。
  2. 对不同训练阶段适应性差
    在训练初期,误差曲面通常较为平坦,较大的学习率有助于快速接近最优解区域。但在训练后期,当接近最优解时,误差曲面变得更加复杂,需要较小的学习率来精确调整权重以避免越过最小值点。固定学习率无法满足这种不同训练阶段的需求。

(二)其他固定参数的局限性

除了学习率,BP 算法中的一些其他参数,如动量项(如果有)、权重初始化范围等,在传统算法中通常也是固定的。这些固定参数在不同的数据集和应用场景下可能无法达到最优的训练效果,进一步限制了网络的性能。

四、自适应调整参数的 BP 网络学习改进算法原理

(一)自适应学习率调整策略

  1. 基于梯度信息调整
    一种常见的方法是根据每次迭代中误差函数对权重的梯度大小来调整学习率。如果当前梯度较大,说明当前权重更新方向可能距离最优解较远,可以适当增大学习率以加快收敛速度;如果梯度较小,接近最优解区域,则减小学习率以避免跳过最小值。例如,可以采用以下简单的调整规则:

η ( t + 1 ) = { γ η ( t ) , if  ∥ ∂ E ∂ w ( t ) ∥ > θ β η ( t ) , if  ∥ ∂ E ∂ w ( t ) ∥ ≤ θ \eta(t + 1)=\begin{cases} \gamma\eta(t), & \text{if } \|\frac{\partial E}{\partial w}(t)\| > \theta\\ \beta\eta(t), & \text{if } \|\frac{\partial E}{\partial w}(t)\| \leq \theta \end{cases} η(t+1)={γη(t),βη(t),if wE(t)>θif wE(t)θ

其中, γ > 1 \gamma > 1 γ>1 β < 1 \beta < 1 β<1 是调整系数, θ \theta θ 是一个阈值, t t t 是训练次数。

  1. 基于训练性能调整
    还可以根据网络在训练集或验证集上的性能来调整学习率。如果在连续几个训练迭代中误差没有明显下降,甚至有上升趋势,可以减小学习率;反之,如果误差下降速度较快,可以适当增大学习率。例如,计算连续 k k k 次迭代的误差变化率:

Δ E r a t e = E ( t − k ) − E ( t ) E ( t − k ) \Delta E_{rate}=\frac{E(t - k)-E(t)}{E(t - k)} ΔErate=E(tk)E(tk)E(t)

根据 Δ E r a t e \Delta E_{rate} ΔErate 的值来调整学习率,如当 Δ E r a t e > δ \Delta E_{rate} > \delta ΔErate>δ δ \delta δ 为正阈值)时增大学习率,当 Δ E r a t e < − δ \Delta E_{rate} < -\delta ΔErate<δ 时减小学习率。

(二)其他参数的自适应调整

  1. 自适应动量项调整
    对于带有动量项的 BP 算法,动量系数也可以自适应调整。在训练初期,为了加速收敛,可以设置较大的动量系数,使权重更新具有较大的惯性;在训练后期,接近最优解时,减小动量系数,避免因惯性过大而错过最小值。例如,可以根据训练次数或当前误差值来调整动量系数。

  2. 自适应权重初始化范围调整
    根据输入数据的特征和网络规模,自适应地确定权重初始化的范围。例如,如果输入数据的方差较大,可以适当增大权重初始化范围;如果网络层数较多,可以适当减小初始化范围以避免梯度消失或爆炸问题。

五、自适应调整参数的 BP 网络学习改进算法步骤

(一)初始化

  1. 权重和偏置初始化
    根据自适应权重初始化范围调整策略确定合适的范围,随机初始化网络的连接权重 w i j w_{ij} wij 和偏置 b j b_{j} bj
  2. 学习率和其他参数初始化
    初始化学习率 η \eta η、动量系数(如果有)等其他参数。同时,设置用于自适应调整参数的相关参数,如调整系数、阈值等。

(二)训练过程

  1. 前向传播
    对于每个训练样本,按照传统 BP 算法的前向传播方式计算网络的输出。
  2. 误差计算与反向传播
    计算当前样本的误差,并根据误差进行反向传播计算梯度。
  3. 参数自适应调整
    • 根据当前梯度信息或训练性能,按照自适应学习率调整策略调整学习率。
    • 如果存在动量项,根据训练阶段调整动量系数。
  4. 权重更新
    根据调整后的学习率(和动量系数,如果有)更新权重和偏置。
  5. 重复训练
    对所有训练样本重复上述步骤,完成一次训练迭代(epoch)。多次重复训练迭代,直到满足训练停止条件,如达到预定的训练次数、误差小于某个阈值或验证集误差不再下降等。

六、代码示例

以下是一个简单的自适应调整学习率的 BP 网络 Python 代码示例:

import numpy as np# Sigmoid 激活函数
def sigmoid(x):return 1 / (1 + np.exp(-x))# Sigmoid 函数的导数
def sigmoid_derivative(x):return x * (1 - x)class NeuralNetwork:def __init__(self, input_size, hidden_size, output_size):self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size# 随机初始化权重self.W1 = np.random.rand(self.input_size, self.hidden_size)self.b1 = np.zeros((1, self.hidden_size))self.W2 = np.random.rand(self.hidden_size, self.output_size)self.b2 = np.zeros((1, self.output_size))self.learning_rate = 0.1  # 初始学习率self.gamma = 1.2  # 学习率增大系数self.beta = 0.8  # 学习率减小系数self.theta = 0.5  # 梯度阈值def forward_propagation(self, X):self.z1 = np.dot(X, self.W1) + self.b1self.a1 = sigmoid(self.z1)self.z2 = np.dot(self.a1, self.W2) + self.b2self.a2 = sigmoid(self.z2)return self.a2def back_propagation(self, X, y):m = X.shape[0]dZ2 = self.a2 - ydW2 = np.dot(self.a1.T, dZ2) / mdb2 = np.sum(dZ2, axis=0, keepdims=True) / mdZ1 = np.dot(dZ2, self.W2.T) * sigmoid_derivative(self.a1)dW1 = np.dot(X.T, dZ1) / mdb1 = np.sum(dZ1, axis=0, keepdims=True) / mreturn dW1, db1, dW2, db2def update_learning_rate(self, dW1, dW2):gradient_norm1 = np.mean(np.sqrt(np.sum(dW1 ** 2, axis=1)))gradient_norm2 = np.mean(np.sqrt(np.sum(dW2 ** 2, axis=1)))if gradient_norm1 > self.theta or gradient_norm2 > self.theta:self.learning_rate *= self.gammaelse:self.learning_rate *= self.betadef update_weights(self, dW1, db1, dW2, db2):self.W1 -= self.learning_rate * dW1self.b1 -= self.learning_rate * db1self.W2 -= self.learning_rate * dW2self.b2 -= self.learning_rate * db2def train(self, X, y, epochs):for epoch in range(epochs):output = self.forward_propagation(X)dW1, db1, dW2, db2 = self.back_propagation(X, y)self.update_learning_rate(dW1, dW2)self.update_weights(dW1, db1, dW2, db2)if epoch % 100 == 0:error = np.mean((output - y) ** 2)print(f'Epoch {epoch}: Error = {error}, Learning Rate = {self.learning_rate}')

你可以使用以下方式测试这个神经网络:

# 示例用法
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])
neural_network = NeuralNetwork(2, 3, 1)
neural_network.train(X, y, 1000)

七、总结

采用自适应调整参数的 BP 网络学习改进算法通过动态调整学习率和其他相关参数,有效地克服了传统 BP 算法中固定参数带来的问题。这种算法能够根据训练过程中的实际情况,如梯度大小、训练性能等,自动调整参数,提高了训练的效率和稳定性,减少了陷入局部最优解的可能性。代码示例展示了自适应学习率调整的简单实现,在实际应用中,可以进一步扩展和优化自适应调整策略,以适应更复杂的数据集和应用场景,从而进一步提升 BP 神经网络的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/19311.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

PlantUML——时序图

PlantUML时序图 背景 时序图&#xff08;Sequence Diagram&#xff09;&#xff0c;又名序列图、循序图&#xff0c;是一种UML交互图&#xff0c;用于描述对象之间发送消息的时间顺序&#xff0c;显示多个对象之间的动态协作。时序图的使用场景非常广泛&#xff0c;几乎各行各…

算法——链表相交(leetcode23)

链表相交这题就是找出两个相交链表相交的节点并返回 如上图假设上方第一个节点是链表A的头结点下方第一个节点是链表B的头结点 解法有以下两种 方法一(移动长链表指针后同步移动两个链表的指针直至相等) 也就是先遍历链表A和链表B的长度接着得到链表A和B长度的差值然后领长链…

STM32单片机锁死

自己画了一块stm32f407板子&#xff0c;外部晶振用了25MHz&#xff0c;烧写了8MHz的程序&#xff0c;第一次烧写成功&#xff0c;第二次开始识别不到芯片&#xff0c;第一次烧写成功由于外部晶振为25Hz&#xff0c;芯片内频率计算器却是按照8MHz写的&#xff0c;所以得出最后的…

Windows文件资源管理器增强工具

引言&#xff1a; 资源管理器在我们使用电脑时是经常用到的&#xff0c;各种文件资源等的分类整理都离不开它。但是Windows Explorer确实不好用&#xff0c;不智能&#xff0c;不符合人体工程学。特别是在一些场合&#xff0c;在打开的一堆文件夹里&#xff0c;想从中找到自己要…

聚类中3个解空间的描述

深度学习中做分类任务时&#xff0c;我们常常根据最后的全连接层得到一组向量A&#xff08;比如&#xff1a;[0.9, 0.7, 0.2]&#xff09;&#xff0c;这组向量经过归一化得到向量B(比如&#xff1a;[0.5&#xff0c; 0.3&#xff0c; 0.2])&#xff0c;再根据B向量采用概率最大…

Empirical analysis of hardware-assisted GPU virtualization

​ 年份&#xff1a;2019 作者&#xff1a;Anshuj Garg 会议&#xff1a;ESCI 出版商&#xff1a;IEEE 摘要 本篇文章对vGPU虚拟化的性能开销、调度算法的影响、同构与异构工作负载的干扰效应&#xff0c;以及PCI直通与vGPU的性能差异进行了研究。结果表明&#xff0c;vGP…

Java面试题2024-Java基础

Java基础 1、 Java语言有哪些特点 1、简单易学、有丰富的类库 2、面向对象&#xff08;Java最重要的特性&#xff0c;让程序耦合度更低&#xff0c;内聚性更高&#xff09; 3、与平台无关性&#xff08;JVM是Java跨平台使用的根本&#xff09; 4、可靠安全 5、支持多线程 2、…

【案例分享】运用 Infragistics Ultimate UI 让工业物联网 IIoT 数据流更易于访问

客户概况 贝克休斯旗下的 Bently Nevada 是状态监测和资产保护领域的全球领导者。该公司拥有 60 多年的专业知识&#xff0c;在全球安装了超过 600 万个传感器和 100,000 个机架监测系统。 如今&#xff0c;Bently Nevada的开发团队正在使用现代 UI 工具包来增强他们的系统&a…

PHM技术:基于支持向量机的智能故障诊断 | 行星齿轮箱智能故障诊断

目录 1.数据获取 2.特征提取与选择 3.健康状态识别 1.数据获取 用的行星齿轮箱数据采集自图1中的多级齿轮传动系统实验台中&#xff0c;在实验过程中&#xff0c;分别模拟了8种行星齿轮箱的健康状态&#xff0c;包括正常、第一级太阳轮点蚀、第一级太阳轮齿根裂纹、第一级…

推荐一款Windows系统精简工具:NTLite

NTLite是一款可以对Windows系统优化的安装工具&#xff0c;使用这款完全中文的NTLite授权注册版让你不会因为注册或者语言导致无法正常的使用&#xff0c;如果你正需要马上下载使用吧。 NTLite基本简介 NTLite 中文版可以用来做什么&#xff0c;它其实是一款 Windows 系统精简…

ESP-IDF VScode 项目构建/增加组件 新手友好!!!

项目构建 1.新建文件夹&#xff0c;同时在该文件夹内新建.c和.h文件 如图所示&#xff0c;在components中新建ADC_User.c、ADC_User.h、CMakeLists.txt文件。当然这里你也可以不在components文件夹内新建文件&#xff0c;下面会说没有在components文件夹内新建文件构建项目的方…

Node Exporter 可观测性最佳实践

Node Exporter 介绍 Node Exporter 是一个开源的 Prometheus 指标收集器&#xff0c;它提供了大量关于宿主机系统的关键指标&#xff0c;如 CPU、内存、磁盘和网络使用情况。在 Kubernetes 环境中&#xff0c;Node Exporter 对于监控集群节点的健康状况至关重要。本文将介绍如…

Spring Boot汽车资讯:科技与速度的交响

3系统分析 3.1可行性分析 通过对本汽车资讯网站实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本汽车资讯网站采用SSM框架&#xff0c;JAVA作为开发语言&#…

外卖跑腿小程序源码如何满足多样需求?

外卖跑腿平台已经成了当代年轻人的便捷之选&#xff0c;校园中也不例外&#xff0c;那么外卖、跑腿小程序就需要满足用户多样化的需求&#xff0c;而这背后的源码扮演者最重要的角色。 用户类型的多样性 1.对上班族而言&#xff0c;他们希望外卖小程序能够快速下单、准确配送…

GeeRPC第一天 服务端与消息编码(1)

RPC 1. 系统架构图解释&#xff08;Graph&#xff09; 架构层次 RPC框架核心功能&#xff1a;这是系统的最上层&#xff0c;涵盖了框架的主要功能模块&#xff0c;直接与底层服务和用户交互。 服务层&#xff1a;主要负责服务的注册、发现和治理。 服务注册&#xff1a;将服务…

如何在谷歌浏览器中开启离线模式

在数字化时代&#xff0c;互联网已经成为我们生活中不可或缺的一部分。然而&#xff0c;有时候我们可能会遇到没有网络连接的情况&#xff0c;这时谷歌浏览器的离线模式就显得尤为重要。本教程将详细介绍如何在谷歌浏览器中轻松开启离线模式&#xff0c;并附带一些相关教程指南…

【进阶系列】正则表达式 #匹配

正则表达式 正则表达式是一个特殊的字符序列&#xff0c;它能帮助你方便的检查一个字符串是否与某种模式匹配。re模块使 Python 语言拥有全部的正则表达式功能。 一个正则表达式的匹配工具&#xff1a;regex101: build, test, and debug regex s "C:\\a\\b\\c" pri…

C++使用Alglib数学库进行非线性最小二乘拟合

目录 一、前言 二、主要函数分析 2.1 lsfitcreatef 2.2 lsfitsetcond 2.3 lsfitfit 2.4 lsfitresults 三、基础代码实现 3.1 定义待拟合函数 3.2 数据拟合 四、可视化代码实现 4.1 拟合h文件 4.2 拟合cpp文件 4.2 代码实验 一、前言 本文记录基于Alglib进行非线性…

Spring Boot汽车世界:资讯与技术的交汇

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

算法--“找零方案”问题

def main():d [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0] # 存储各种硬币的面值d_num [] # 存储每种硬币的数量total_money 0 # 收银员拥有的总金额# 输入每种硬币的数量temp input(请输入每种零钱的数量&#xff08;以空格分隔&#xff09;:)d_num0 temp.split() # 以空…