梯度下降法以及 Python 实现

文章目录

1. 引言

梯度下降法,可以根据微分求出的斜率计算函数的最小值。
在人工智能中,经常被应用于学习算法。

2. 梯度法

梯度法 是根据函数的微分值(斜率)搜索最小值的算法。

梯度下降法也是一种梯度法,它通过向最陡方向下降来查找最小值。

给定一个多变量函数:
f ( x ) = f ( x 1 , x 2 , … , x i , … , x n ) . f(x) = f(x_1, x_2, \dots, x_i, \dots, x_n). f(x)=f(x1,x2,,xi,,xn).
首先为 x x x 赋予一个合适的初始值,然后通过下面的表达式进行更新:
x i t + 1 = x i t − η ∂ f ( x ) ∂ x i . x^{t+1}_i = x^{t}_i - \eta \frac{\partial f(x)}{\partial x_i}. xit+1=xitηxif(x).

其中, ∂ f ( x ) ∂ x i \displaystyle \frac{\partial f(x)}{\partial x_i} xif(x) 表示函数 f ( x ) f(x) f(x) 对变量 x i x_i xi 的偏导数。 x i t x^{t}_i xit 表示第 t t t 次迭代时变量 x i x_i xi 的取值, x i t + 1 x^{t+1}_i xit+1 表示第 t + 1 t+1 t+1 次迭代时变量 x i x_i xi 的取值。需要说明的是, t t t 是一个非负整数,也即是 t ∈ N t \in \mathbb{N} tN

η \eta η 是一个重要的参数,被称为学习系数或学习率的常数。 η \eta η 决定了 x i x_i xi 的更新速度。可以理解为,一个人 P 要从 A 点走到 B 点,, η \eta η 就是 P 走路时每一步的跨步大小,也称为步长。

根据该表达式, ∂ f ( x ) ∂ x i \displaystyle \frac{\partial f(x)}{\partial x_i} xif(x) 越大,也即是坡度越陡, x i x_i xi 值的变化就越大。

重复此操作,直到 f ( x ) f(x) f(x) 停止变化,那么此时 f ( x ) f(x) f(x) 的值就是 min ⁡ f ( x ) \min f(x) minf(x)

3. 例子

给定一个单变量函数 f ( x ) f(x) f(x)
f ( x ) = x 2 − 2 x . f(x)= x^2 - 2x. f(x)=x22x.
f ( x ) f(x) f(x) 的最小值。

:函数 f ( x ) f(x) f(x) 的导数记为 f ′ ( x ) f'(x) f(x)
f ′ ( x ) = d f ( x ) d x = 2 x − 2. f'(x)=\frac{\mathrm{d} f(x)}{\mathrm{d} x}=2x-2. f(x)=dxdf(x)=2x2.
f ′ ( x ) = 0 f'(x)=0 f(x)=0,则
f ′ ( x ) = 0 ⇒ 2 x − 2 = 0 x = 1. \begin{aligned} f'(x) =0 \Rightarrow 2x-2 & = 0 \\ x & = 1. \\ \end{aligned} f(x)=02x2x=0=1.
即当 x = 1 x=1 x=1 处, f ( x ) f(x) f(x) 的导数 f ′ ( x ) f'(x) f(x) 为 0。

x = 1 x=1 x=1 带入到 f ( x ) f(x) f(x) 中,得到:
f min ⁡ ( x ) = f ( x = 1 ) = 1 2 − 2 ∗ 1 = − 1. f_{\min}(x)=f(x=1)=1^2-2*1=-1. fmin(x)=f(x=1)=1221=1.

f ( x ) f(x) f(x) 的最小值在 x = 1 x=1 x=1 处取得,最小值为 -1。


下面通过模拟梯度下降法来求解。

假设 x x x 的初始值为 2,即 x 0 = 2 x^0=2 x0=2,令学习率 η = 0.1 \eta=0.1 η=0.1

次数 t t t变量 x t x^t xt导数 f ′ ( x t ) = 2 x t − 2 f'(x^t)=2x^t-2 f(xt)=2xt2函数 f ( x t ) = ( x t ) 2 − 2 x t f(x^t)=(x^t)^2-2x^t f(xt)=(xt)22xt更新 x t + 1 x^{t+1} xt+1
0 x 0 = 2 x^0=2 x0=2 f ′ ( x 0 ) = 2 ∗ 2 − 2 = 2 f'(x^0)=2*2-2=2 f(x0)=222=2 f ( x 0 ) = 2 2 − 2 ∗ 2 = 0 f(x^0)=2^2-2*2=0 f(x0)=2222=0 x 1 = 2 − 0.1 ∗ 2 = 1.8 x^1=2-0.1*2=1.8 x1=20.12=1.8
1 x 1 = 1.8 x^1=1.8 x1=1.8 f ′ ( x 1 ) = 2 ∗ 1.8 − 2 = 1.6 f'(x^1)=2*1.8-2=1.6 f(x1)=21.82=1.6 f ( x 1 ) = 1. 6 2 − 2 ∗ 1.6 = − 0.64 f(x^1)=1.6^2-2*1.6=-0.64 f(x1)=1.6221.6=0.64 x 2 = 1.8 − 0.1 ∗ 1.6 = 1.64 x^2=1.8-0.1*1.6=1.64 x2=1.80.11.6=1.64
2 x 2 = 1.64 x^2=1.64 x2=1.64 f ′ ( x 2 ) = 2 ∗ 1.64 − 2 = 1.28 f'(x^2)=2*1.64-2=1.28 f(x2)=21.642=1.28 f ( x 2 ) = 1.6 4 2 − 2 ∗ 1.64 = − 0.5904 f(x^2)=1.64^2-2*1.64=-0.5904 f(x2)=1.64221.64=0.5904 x 3 = 1.64 − 0.1 ∗ 1.28 = 1.512 x^3=1.64-0.1*1.28=1.512 x3=1.640.11.28=1.512
3 x 3 = 1.512 x^3=1.512 x3=1.512 f ′ ( x 3 ) = 2 ∗ 1.512 − 2 = 1.024 f'(x^3)=2*1.512-2=1.024 f(x3)=21.5122=1.024 f ( x 3 ) = 1.51 2 2 − 2 ∗ 1.512 = − 0.7379 f(x^3)=1.512^2-2*1.512=-0.7379 f(x3)=1.512221.512=0.7379 x 4 = 1.512 − 0.1 ∗ 1.024 = 1.4096 x^4=1.512-0.1*1.024=1.4096 x4=1.5120.11.024=1.4096
4 x 4 = 1.4096 x^4=1.4096 x4=1.4096 … \dots … \dots … \dots

根据梯度下降法的公式进行计算,可以得到上面的表格。可以观察到,导数 f ′ ( x ) f'(x) f(x) 的值越来越小。继续计算上面的表, x x x 的值会越来越小,逐渐逼近 1。当 f ′ ( x ) = 0 f'(x)=0 f(x)=0 时, x = 1 x=1 x=1,此时 f ( x ) = − 1 f(x)=-1 f(x)=1

4. 代码实现

我们利用 Python 代码可以模拟上面的梯度下降过程。

定义一个函数,表示 f ( x ) f(x) f(x)

def my_func(x):"""$y = x^2 - 2x$:param x: 变量:return: 函数值"""return x**2 - 2*x

变量 x 对应于 x x x,my_func() 的结果(返回值)对应于 f ( x ) f(x) f(x)

再定义一个函数,表示 f ′ ( x ) f'(x) f(x)

def grad_func(x):"""函数 $y = x^2 - 2x$ 的导数:param x: 变量:return: 导数值"""return 2*x - 2

变量 x 对应于 x x x,grad_func() 的结果(返回值)对应于 f ′ ( x ) f'(x) f(x)

给定一个学习率 η \eta η,给定一个 x x x 的初始值

eta = 0.1
x = 4.0

那么就可以开始模拟梯度下降法求解最小值。

import numpy as np
import matplotlib.pyplot as pltdef my_func(x):"""$y = x^2 - 2x$:param x: 变量:return: 函数值"""return x**2 - 2*xdef grad_func(x):"""函数 $y = x^2 - 2x$ 的导数:param x: 变量:return: 导数值"""return 2*x - 2eta = 0.1
x = 4.0
record_x = []
record_y = []for i in range(20):y = my_func(x)record_x.append(x)record_y.append(y)x -= eta * grad_func(x)print(np.round(record_x, 4))
print(np.round(record_y, 4))x_f = np.linspace(-2, 4)
y_f = my_func(x_f)plt.plot(x_f, y_f, linestyle='--', color='red')
plt.scatter(record_x, record_y)plt.xlabel('x', size=14)
plt.ylabel('y', size=14)
plt.grid()
plt.show()

x x x 的变化过程为:

[4. 3.4 2.92 2.536 2.2288 1.983 1.7864 1.6291 1.5033 1.4027 1.3221 1.2577 1.2062 1.1649 1.1319 1.1056 1.0844 1.0676 1.054 1.0432]

f ( x ) f(x) f(x) 的变化过程为:

[ 8. 4.76 2.6864 1.3593 0.5099 -0.0336 -0.3815 -0.6042 -0.7467 -0.8379 -0.8962 -0.9336 -0.9575 -0.9728 -0.9826 -0.9889 -0.9929 -0.9954 -0.9971 -0.9981]

我们使用了 matplotlib 可视化函数 f ( x ) f(x) f(x) 的图像,以及梯度下降法求解的过程。

在这里插入图片描述

红色虚线是函数 f ( x ) f(x) f(x) 的图像。

蓝色点表示梯度下降法求解过程中 f ( x ) f(x) f(x) 的值。

5. 讨论 — 学习率 η \eta η

学习率 η \eta η)是一个非常重要的参数。有多重要呢?请接着看……

在上面的例子中,我们设置学习率为 0.1,即 η = 0.1 \eta = 0.1 η=0.1。同样的以上面的例子为例,我们修改学习率。

5.1 当 η \eta η 设置过大

设置 η = 1 \eta = 1 η=1 x x x 的初始值保持一致,仍取值 4.0。

eta = 1
x = 4.0

那么,再一次利用梯度下降法求解,
x x x 的变化过程为:

[ 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2.]

f ( x ) f(x) f(x) 的变化过程为:

[8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8.]

可视化结果为:

在这里插入图片描述

上面的输出结果和图像都可以看出, x x x f ( x ) f(x) f(x) 的结果在循环,始终无法得到正确结果,进入了死循环。

5.2 当 η \eta η 设置过小

设置 η = 0.01 \eta = 0.01 η=0.01 x x x 的初始值保持一致,仍取值 4.0。

eta = 0.01
x = 4.0

那么,再一次利用梯度下降法求解,
x x x 的变化过程为:

[4. 3.94 3.8812 3.8236 3.7671 3.7118 3.6575 3.6044 3.5523 3.5012 3.4512 3.4022 3.3542 3.3071 3.2609 3.2157 3.1714 3.128 3.0854 3.0437]

f ( x ) f(x) f(x) 的变化过程为:

[8. 7.6436 7.3013 6.9726 6.6569 6.3537 6.0625 5.7828 5.5142 5.2562 5.0085 4.7705 4.542 4.3226 4.1118 3.9094 3.7149 3.5282 3.3489 3.1767]

可视化结果为:

在这里插入图片描述

上面的输出结果和图像都可以看出,梯度下降法在正确工作。但是求解过程很缓慢,离最小值还有一段距离。此时需要增加循环轮次,消耗更多的资源。

总结:需要设置合理的学习率 η \eta η,过大或过小都不好。

参考

-《用Python编程和实践!数学教科书》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/35484.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV-图像阈值

简单阈值法 此方法是直截了当的。如果像素值大于阈值,则会被赋为一个值(可能为白色),否则会赋为另一个值(可能为黑色)。使用的函数是 cv.threshold。第一个参数是源图像,它应该是灰度图像。第二…

详细了解IO分类

按照数据流方向如何划分? 输入流(Input Stream):从源(如文件、网络等)读取数据到程序。 输出流(Output Stream):将数据从程序写出到目的地(如文件、网络、控…

Appium 安装问题汇总

好生气好生气,装了几天了, opencv4nodejs 和 mjpeg-consumer 就是装不了,气死我了不管了,等后面会装的时候再来完善,气死了气死了。 目录 前言 1、apkanalyzer.bat 2、opencv4nodejs 3、ffmpeg 4、mjpeg-consume…

目标检测知识点总结

1、数据增强 2、指标 3、vit 、swint ViT算法,创新性地将图像划分成一个个patch,并将每个patch展平为一个向量,使得图像数据转化为序列化数据,之后输入到Transformer模型中,实现了Transformer在图像分类任务中的应用。…

Lecture 11 - List,Set,Map

List, Set and Map are all interfaces: they define how these respective types work, but they don’t provide implementation code overview 1. List(列表): (1) 创建、访问和操作列表:ArrayList …

ElfBoard开源项目|基于百度智能云平台的车牌识别项目

本项目基于百度智能云平台,旨在利用其强大的OCR服务实现车牌号码的自动识别。选择百度智能云的原因是其高效的API接口和稳定的服务质量,能够帮助开发者快速实现车牌识别应用。 本项目使用摄像头捕捉图像后,通过集成百度OCR服务的API&#xf…

应对超声波测试挑战,如何选择合适的数字化仪?

一、数字化仪的超声波应用 超声波是频率大于人类听觉范围上限的声学声压(声学)波。超声波设备的工作频率为20 kHz至几千MHz。表1总结了一些更常见的超声波应用的特征。 每个应用中使用的频率范围都反映了实际情况下的平衡。提高工作频率可以通过提高分…

product/admin/list?page=0size=10field=jancodevalue=4562249292272

文章目录 1、ProductController2、AdminCommonService3、ProductApiService4、ProductCommonService5、ProductSqlService https://api.crossbiog.com/product/admin/list?page0&size10&fieldjancode&value45622492922721、ProductController GetMapping("ad…

博物馆导览系统方案(一)背景需求分析与核心技术实现

维小帮提供多个场所的室内外导航导览方案,如需获取博物馆导览系统解决方案可前往文章最下方获取,如有项目合作及技术交流欢迎私信我们哦~撒花! 一、博物馆导览系统的背景与市场需求 在数字化转型的浪潮中,博物馆作为文化传承和知…

Flink学习连载文章11--双流Join

双流 Join 和两个流合并是不一样的 两个流合并:两个流变为 1 个流 union connect 双流 join: 两个流 join,其实这两个流还是原来的,只是满足条件的数据会变为一个新的流。 可以结合 sql 语句中的 union 和 join 的区别。 在离线 Hive 中&…

vue3+wangeditor富文本编辑器详细教程

一、前言 在这篇教程中,我将指导如何使用 Vue 3 和 WangEditor 创建一个功能丰富的富文本编辑器。WangEditor 是一个轻量级的富文本编辑器,它非常适合集成到 Vue 项目中。这个例子展示了如何配置富文本编辑器,包括工具栏、编辑器配置以及如何…

Python学习39天

my_tools.py文件提供工具函数 """ 此文件编写工具函数,供程序员使用 my_tools """def read_confirm_select():"""让用户输入:Y/N,不区分大小写,将用户输入值转为小写返回&#xff…

LCA - Lowest Common Ancestor

LCA - Lowest Common Ancestor https://www.luogu.com.cn/problem/SP14932 题目描述 A tree is an undirected graph in which any two vertices are connected by exactly one simple path. In other words, any connected graph without cycles is a tree. - Wikipedia T…

unity打包web,发送post请求,获取地址栏参数,解决TypeError:s.replaceAll is not a function

发送post请求 public string url "http://XXXXXXXXX";// 请求数据public string postData "{\"user_id\": 1}";// Start is called before the first frame updatevoid Start(){// Post();StartCoroutine(PostRequestCoroutine(url, postData…

恒创科技:如何区分网站的域名主机名

如何区分网站的域名主机名?它们都是网址机制的一部分,当你在地址栏输入它们,就能访问互联网上想去的地方。你可曾思考过主机名和域名的区别呢? 简单来说,域名就像网址,而主机名用于标识网络中的设备。不过,这只是表面…

【技巧学习】ArcGIS如何计算水库库容量?

ArcGIS如何计算水库库容量? 一、数据获取 DEM数据来源于地理空间数据云,该网站是由中科院计算机网络信息中心于2008年创立的地学大数据平台。 二、填洼 将DEM数据中凹陷的区域填充至与倾斜点同样高度,这里的【Z限制】说的是设定一个特定的值&#x…

机器学习——感知机模型

文章目录 前言1.感知机模型介绍1.1基本概念1.2数学表达1.3几何解释1.4优缺点 2.二分类应用2.1应用介绍2.2准备数据集2.2.1环境检查2.2.2数据集介绍2.2.3获取数据2.2.4划分数据集 2.3可视化训练集2.4训练过程2.4.1首轮梯度下降2.4.2多轮梯度下降 2.5可视化分类结果2.6在验证集验…

11.20[JAVAEXP3]重定向细究【DEBUG】

设置了根域名访问为testServlet,让他重定向到首页为test.jsp,事实上也都触发了,但是最后显示的为什么不是test.jsp生成页面,依然还是index.jsp生成的页面?? 重定向是通过Dispatcher进行的,而不是sendRedir…

YOLOv11模型改进-注意力-引入卷积和注意力融合模块(CAFM) 提升小目标和遮挡检测

本篇文章将介绍一个新的改进机制——卷积和注意力融合模块CAFM,并阐述如何将其应用于YOLOv11中,显著提升模型性能。首先,CAFM是为了融合卷积神经网络(CNNs)和 Transformer 的优势,同时对全局和局部特征进行…

APM装机教程(五):测绘无人船

文章目录 前言一、元生惯导RTK使用二、元厚HXF260测深仪使用三、云卓H2pro遥控器四、海康威视摄像头 前言 船体:超维USV-M1000 飞控:pix6c mini 测深仪:元厚HXF160 RTK:元生惯导RTK 遥控器:云卓H12pro 摄像头&#xf…