机器学习day6-线性代数2-梯度下降

6梯度下降

1.概念

引入梯度下降的原因:1.并不是所有的机器学习都是凸函数,并且可能存在多个极值,不能确定唯一解;2.特征和数据量较多时,矩阵计算量太大(逆矩阵运算的时间复杂度是(

$$
O(n^3)
$$

))

梯度下降:梯度表示损失函数对于模型参数的偏导数。对于每个可训练的参数,梯度告诉我们在当前参数下,沿着每个参数方向变化时,损失函数的变化率。通过计算损失函数对参数的梯度,算法可以根据梯度信息来调整参数从而沿着减少损失的方向更新模型进行调优。

$$
\bar e={\frac{1}{n}}\sum_{i=1}^{n}x_{i}^{2}w^{2}-{\frac{2}{n}}\sum_{i=1}^{n} x_{i}y_{i}w+{\frac{1}{n}}\sum_{i=1}^{n} y_{i}^{2}
$$

中损失函数对于参数w的梯度就是此时w的切线斜率。

梯度下降法是通过不断地优化从而得到最优解。(优化算法都是期望以最快的速度把模型参数w求解出来)

2.步骤

①.Random随机生成初始w,随机生成一组正态分布的数值

$$
w_0,w_1,w_2....w_n
$$

②.求梯度g,即损失函数在此w点上的切线斜率(求导)

③.g<0,表示切线斜率为负数,表示在导数为0的w值的左边,那么就应该把w往右调大不断逼近导数为0的w值;g>0,表示切线斜率为正,表示在导数为0的w值的右边,将w往左调小不断逼近导数为0的w值

④.判断是否收敛,如果收敛就跳出迭代,否则重复②-④。判断收敛的标准是:随着迭代看loss的值变化多少,如果很小甚至不再变化,则认为达到迭代。(迭代次数可以固定)

3.公式

$$
w^{n+1}=w^n-α*gradient
$$

$$
新值=旧值-学习率*导数
$$

学习率:学习率很小可以保证收敛,但是会增加计算;学习率较大,训练会震荡收敛;学习率过大,系统会来回震荡无法收敛

一般将学习率设置成一个小数,0.1,0.01,0.001,0.0001等,可以根据情况进行调整。一般学习率在整体迭代过程中不变,也可以设置成随着迭代次数增多学习率逐渐变小以更精准地得到最优解。

按照步骤来计算w的值

假设要求的损失函数的抛物线公式为:

$$
\bar e=10w^{2}-15.9w+6.5
$$

那么切线公式为:

$$
e'=20w-15.9
$$

w0=0.2,假设学习率为0.01(w0表示第0次w的值)

w1=w0-0.01*e'(w0)=0.2-0.01 *(20 * 0.2-15.9)=0.319

w2=w1-0.01 * e'(w1)=0.319-0.01 * (20 * 0.319-15.9)=0.4142

以此类推

w在最低点的左边还是右边:导数为负值,则w在最低值的左边,应该往右移动(减去一个负数=加上一个正数=往右移动);导数为正值,则w在最低值的右边,应该往左移动(加上一个负数=往左移动)

4.代码实现梯度下降

①.损失函数为一个特征的抛物线:

$$
loss(w_1)=(w_1-3.5)^2-4.5w_1+10
$$

#自己实现一维梯度下降算法(一个w)
import numpy as np
import matplotlib.pyplot as plt
​
w=np.linspace(-10,10,100)
#print(w)
def loss(w):return (w-3.5)**2-4.5*w+10
​
def dloss(w):return 2*(w-3.5)-4.5
#print(loss)
plt.plot(w,loss(w))
plt.show()
​
#梯度下降
#学习率
learning_rate=0.1
#初始化一个w值
np.random.seed(1)
w=np.random.randint(-10,20)#随机给一个w值:-5
print(w)
e=loss(w)#初始化的w为-5时的loss值
​
x=[w]
y=[e]
#第1次梯度下降
w=w-learning_rate*dloss(w)
e=loss(w)
x.append(w)
y.append(e)
#第2次梯度下降
w=w-learning_rate*dloss(w)
e=loss(w)
x.append(w)
y.append(e)
#第3次梯度下降
w=w-learning_rate*dloss(w)
e=loss(w)
x.append(w)
y.append(e)
​
plt.scatter(x,y)#画点
plt.show()
#自己实现一维梯度下降算法(一个w)-循环
import numpy as np
import matplotlib.pyplot as plt
#假设有一个函数是y=wx 通过损失函数的思路已经得出了损失函数
def loss(w):return (w-3.5)**2-4.5*w+10
#得出导函数
def dloss(w):return x*(w-3.5)-4.5
def train():#初始化随机给一个w值w=-10#np.random.randint(-10,20)#初始化学习率lr=0.1#更新次数epoch=1000#梯度下降更新wt0,t1=1,100for i in range(0,epoch):#print(i)lr=t0/(t1+i)w=w-lr*dloss(w)print(f"第{i}次w更新后的值:{w},更新后损失函数的值{loss(w)}")
​
train()
​

②.损失函数为两个特征的抛物线:

$$
loss(w_1,w_2)=(w_1-3.5)^2+(w_2-2)^2+3w_1w_2-4.5w_1+2w_2+20
$$

分别对w1,w2求导

$$
对w_1求导,就是把w_2看成一个常数: loss(w1)’=2(w_1-3.5)+0+3w_2-4.5=2w_1+3w_2-11.5
$$

$$
对w_2求导,就是把w_1看成一个常数: loss(w2)’=0+2(w_2-2)+3w_1+2=2w_2+3w_1-2
$$

参照一维求梯度下降的算法:

#自己实现二维梯度下降算法(两个w)
import numpy as np
​
#假设有一个函数是y=wx 通过损失函数的思路已经得出了损失函数
def loss(w1,w2):return (w1-3.5)**2+(w2-2)**2+3*w1*w2-4.5*w1+2*w2+20
#得出导函数
def dloss_w1(w1,w2):return 2*(w1-3.5)+3*w2-4.5
def dloss_w1(w1,w2):return 2*(w2-2)+3*w1+2
#梯度下降算法
def train():#初始化随机给一个w1,w2值w1=10#np.random.randint(-10,20)w2=10#初始化学习率lr=0.1#更新次数epoch=1000#梯度下降更新wt0,t1=1,100for i in range(0,epoch):#更新w1,w2lr=t0/(t1+i)w1_=w1w2_=w2w1=w1-lr*dloss(w1_,w2_)w2=w2-lr*dloss(w1_,w2_)print(i,w1,w2,loss(w1_,w2_))#print(f"第{i}次w更新后的值:{w1,w2},更新后损失函数的值{loss(w)}")
​
train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/18982.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

2024.6使用 UMLS 集成的基于 CNN 的文本索引增强医学图像检索

Enhancing Medical Image Retrieval with UMLS-Integrated CNN-Based Text Indexing 问题 医疗图像检索中&#xff0c;图像与相关文本的一致性问题&#xff0c;如患者有病症但影像可能无明显异常&#xff0c;影响图像检索系统准确性。传统的基于文本的医学图像检索&#xff0…

H.264/H.265播放器EasyPlayer.js网页直播/点播播放器关于播放的时候就有声音

EasyPlayer.js H5播放器&#xff0c;是一款能够同时支持HTTP、HTTP-FLV、HLS&#xff08;m3u8&#xff09;、WS、WEBRTC、FMP4视频直播与视频点播等多种协议&#xff0c;支持H.264、H.265、AAC、G711A、Mp3等多种音视频编码格式&#xff0c;支持MSE、WASM、WebCodec等多种解码方…

Redis 的代理类注入失败,连不上 redis

在测试 redis 是否成功连接时&#xff0c;发现 bean 没有被创建成功&#xff0c;导致报错 根据报错提示&#xff0c;需要我们添加依赖&#xff1a; <dependency><groupId>org.apache.commons</groupId><artifactId>commons-pool2</artifactId>&l…

Prometheus结合K8s(一)搭建

公司之前K8s集群没有监控&#xff0c;top查看机器cpu使用率很高&#xff0c;为了监控pod的cpu和内存&#xff0c;集群外的mysql资源&#xff0c;初步搭建了Prometheus监控系统 提前准备镜像 docker.io/grafana/grafana 10.4.4 docker.io/prom/prometheus v2.47.2 docker.io/…

Vscode/Code-server无网环境安装通义灵码

Date: 2024-11-18 参考材料&#xff1a;https://help.aliyun.com/zh/lingma/user-guide/individual-edition-login-tongyi-lingma?spma2c4g.11186623.0.i0 1. 首先在vscode/code-server插件市场中安装通义插件&#xff0c;这步就不细说了。如果服务器没网&#xff0c;会问你要…

【划分型DP-约束划分个数】力扣813. 最大平均值和的分组

给定数组 nums 和一个整数 k 。我们将给定的数组 nums 分成 最多 k 个非空子数组&#xff0c;且数组内部是连续的 。 分数 由每个子数组内的平均值的总和构成。 注意我们必须使用 nums 数组中的每一个数进行分组&#xff0c;并且分数不一定需要是整数。 返回我们所能得到的最…

IDEA:2023版远程服务器debug

很简单&#xff0c;但是很多文档没有写清楚&#xff0c;wocao 一、首先新建一个远程jvm 二、配置 三、把上面的参数复制出来 -agentlib:jdwptransportdt_socket,servery,suspendn,address5005 四、然后把这串代码放到服务器中 /www/server/java/jdk1.8.0_371/bin/java -agentl…

centos安装jenkins

本机使用虚拟机centos 7.9.2009 安装gitlab&#xff0c;本机的虚拟机ip地址是 192.168.60.151&#xff0c; 步骤记录如下 1、下载jenkins&#xff0c;安装jenkins之前需要安装jdk jdk和jenkins的版本对应关系参考&#xff1a;Redhat Jenkins Packages Index of /redhat-stable…

蜀道山CTF<最高的山最长的河>出题记录

出这道题的最开始感觉就是,因为现在逆向的形式好多,我最开始学习的时候,经常因为很多工具,或者手段完全不知道,就很懵逼,很多师傅都出了各种类型的,我就想着给以前的"自己"出一道正常exe,慢慢调的题,为了不那么简单,我就选择了C(究极混淆,可能比rust好点),让大家无聊…

中伟视界:AI智能分析算法如何针对非煤矿山的特定需求,提供定制化的安全生产解决方案

非煤矿山智能化改造&#xff0c;除了政策文件&#xff0c;上级监管单位需要安装的AI智能分析算法功能之外的&#xff0c;矿方真正关心的&#xff0c;能解决矿方安全生产隐患的AI智能分析算法功能有哪些呢&#xff1f; 经过与矿方的现场交流沟通&#xff0c;收集第一现场人员对安…

如何生成python项目需要的最小requirements.txt文件?

今天咱们来聊聊 Python 项目中如何生成一个“最小的” requirements.txt 文件。我们都知道&#xff0c;当我们开发一个 Python 项目的时候&#xff0c;很多时候都会在一个虚拟环境中进行&#xff0c;这样一来&#xff0c;就能避免不同项目之间的依赖冲突。 可有时候&#xff0c…

每日论文22-24ESSERC一种54.6-65.1GHz多路径同步16振荡器

《A 54.6-65.1 GHz Multi-Path-Synchronized 16-Core Oscillator Achieving −131.4 dBc/Hz PN and 195.8 dBc/Hz FoMT at 10 MHz Offset in 65nm CMOS》24欧洲固态 本文是在60GHz 16核VCO的工作&#xff0c;主要亮点在于每一组中四个VCO之间的三路同步拓扑结构&#xff0c;有…

web——upload-labs——第十一关——黑名单验证,双写绕过

还是查看源码&#xff0c; $file_name str_ireplace($deny_ext,"", $file_name); 该语句的作用是&#xff1a;从 $file_name 中去除所有出现在 $deny_ext 数组中的元素&#xff0c;替换为空字符串&#xff08;即删除这些元素&#xff09;。str_ireplace() 在处理时…

网络安全之国际主流网络安全架构模型

目前&#xff0c;国际主流的网络安全架构模型主要有&#xff1a; ● 信息技术咨询公司Gartner的ASA&#xff08;Adaptive Security Architecture自适应安全架构&#xff09; ● 美国政府资助的非营利研究机构MITRE的ATT&CK&#xff08;Adversarial Tactics Techniques &…

StarRocks 架构

StarRocks 是什么&#xff1f;&#xff08; What is StarRocks?&#xff09; StarRocks 是 MPP 的查询引擎&#xff0c;用来做实时查询&#xff0c;提供亚秒级的查询性能。 兼容 MYSQL 协议&#xff0c;可以和大部分 BI 工具进行无缝衔接。 Apache 2.0 开源产品。 使用场景&…

图像处理 之 凸包和最小外围轮廓生成

“ 最小包围轮廓之美” 一起来欣赏图形之美~ 1.原始图片 男人牵着机器狗 2.轮廓提取 轮廓提取 3.最小包围轮廓 最小包围轮廓 4.凸包 凸包 5.凸包和最小包围轮廓的合照 凸包和最小包围轮廓的合照 上述图片中凸包、最小外围轮廓效果为作者实现算法生成。 图形几何之美系列&#…

【机器学习】决策树算法原理详解

决策树 1 概述 1.1 定义 决策树是一种解决分类问题的算法&#xff0c;决策树算法采用树形结构&#xff0c;使用层层推理来实现最终的分类。 决策树即可以做分类&#xff0c;也可以做回归。它主要分为两种&#xff1a;分类树 和 回归树。 1.2 决策树算法 第一个决策树算法…

基于深度学习的车牌检测系统的设计与实现(安卓、YOLOV、CRNNLPRNet)+文档

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

中国省级金融发展水平指数(金融机构存款余额、贷款余额、GDP)2020-2023年

数据范围&#xff1a; 包含的数据内容如下&#xff1a; 分省份金融机构存款余额、分省份金融机构贷款余额、分省份金融机构存贷款余额、分省份GDP、分省份金融发展指数 西藏自治区、贵州省、黑龙江省2023年数据暂未公布&#xff0c;计算至2022年&#xff0c;其他省份数据无缺失…

如何在 Ubuntu 上安装 Mosquitto MQTT 代理

如何在 Ubuntu 上安装 Mosquitto MQTT 代理 Mosquitto 是一个开源的消息代理&#xff0c;实现了消息队列遥测传输 (MQTT) 协议。在 Ubuntu 22.04 上安装 MQTT 代理&#xff0c;您可以利用 MQTT 轻量级的 TCP/IP 消息平台&#xff0c;该平台专为资源有限的物联网 (IoT) 设备设计…