AI学习指南深度学习篇-Adam的Python实践

AI学习指南深度学习篇-Adam的Python实践

在深度学习领域,优化算法是影响模型性能的关键因素之一。Adam(Adaptive Moment Estimation)是一种广泛使用的优化算法,因其在多种问题上均表现优异而被广泛使用。本文将深入探讨Adam优化器,并提供详细的代码示例,展示如何在Python的深度学习库(如TensorFlow和PyTorch)中实现Adam,进行模型训练以及调参过程。

引言

优化算法的选择会影响深度学习模型的收敛速度和最终性能。Adam算法不仅结合了动量(Momentum)的优点,还引入了自适应学习率,这使得其在许多任务中表现良好。本文将通过实际代码示例介绍Adam的实现和调参过程,让读者能够在自己的项目中有效应用这一算法。

Adam优化器概述

2.1 公式推导

Adam优化器的核心思想是计算梯度的动量以及梯度的平方动量,并利用这两个动量来调整学习率。Adam的更新公式如下:

  1. 初始化参数

    • ( m t = 0 ) ( m_t = 0 ) (mt=0)(一阶矩估计)
    • ( v t = 0 ) ( v_t = 0 ) (vt=0)(二阶矩估计)
    • ( t = 0 ) ( t = 0 ) (t=0)(时间步长)
    • ( β 1 , β 2 ) ( \beta_1, \beta_2 ) (β1,β2)(通常取值为0.9,0.999)
    • ( ϵ ) ( \epsilon ) (ϵ)(通常取小值以避免除零错误)
  2. 参数更新
    [ t = t + 1 ] [ t = t + 1 ] [t=t+1]
    [ m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t ] [ m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t ] [mt=β1mt1+(1β1)gt]
    [ v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 ] [ v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 ] [vt=β2vt1+(1β2)gt2]
    [ m ^ t = m t 1 − β 1 t ] [ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} ] [m^t=1β1tmt]
    [ v ^ t = v t 1 − β 2 t ] [ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} ] [v^t=1β2tvt]
    [ θ t = θ t − 1 − α v ^ t + ϵ ⋅ m ^ t ] [ \theta_{t} = \theta_{t-1} - \frac{\alpha}{\hat{v}_t + \epsilon} \cdot \hat{m}_t ] [θt=θt1v^t+ϵαm^t]

2.2 参数说明

  • 学习率 ( ( α ) ) ((\alpha)) ((α)):控制每次更新的步幅,通常初始值设为0.001。
  • ( β 1 ) (\beta_1) (β1) ( β 2 ) (\beta_2) (β2):分别控制一阶矩和二阶矩的衰减率。
  • ( ϵ ) (\epsilon) (ϵ):通常设为 ( 1 0 − 8 ) (10^{-8}) (108),避免在计算时出现除零错误。

在TensorFlow中使用Adam

3.1 环境准备

确保你的计算环境中安装了TensorFlow和其他必要的库:

pip install tensorflow numpy matplotlib

3.2 数据加载

我们将使用Keras提供的MNIST手写数字数据集作为示例:

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

3.3 构建模型

我们将定义一个简单的神经网络模型:

def create_model():model = models.Sequential()model.add(layers.Flatten(input_shape=(28, 28)))model.add(layers.Dense(128, activation="relu"))model.add(layers.Dropout(0.2))model.add(layers.Dense(10, activation="softmax"))return model

3.4 训练模型

使用Adam优化器训练模型:

model = create_model()# 编译模型
model.compile(optimizer="adam",loss="categorical_crossentropy",metrics=["accuracy"])# 训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

3.5 调整超参数

可以通过以下方式调整超参数,比如修改学习率或尝试不同的批大小:

from tensorflow.keras.optimizers import Adam# 创建自定义Adam优化器
adam = Adam(learning_rate=0.001)# 重新编译模型
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=["accuracy"])# 重新训练模型
history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.2)

在PyTorch中使用Adam

4.1 环境准备

确保你的计算环境中安装了PyTorch和其他必要的库:

pip install torch torchvision numpy matplotlib

4.2 数据加载

与TensorFlow类似,我们将使用同样的数据集:

import torch
from torchvision import datasets, transforms
from torch import nn, optim# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 加载MNIST数据集
trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

4.3 构建模型

PyTorch模型构建如下:

class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.dropout = nn.Dropout(0.2)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.shape[0], -1)  # 展平操作x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xmodel = SimpleNN()

4.4 训练模型

使用Adam优化器训练模型的示例如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 10
for epoch in range(epochs):running_loss = 0for images, labels in trainloader:optimizer.zero_grad()  # 清空梯度output = model(images)  # 前向传播loss = criterion(output, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

4.5 调整超参数

在PyTorch中,你也可以像在TensorFlow中那样调整超参数,下面是修改学习率的例子:

# 创建自定义Adam优化器
optimizer = optim.Adam(model.parameters(), lr=0.0001)# 重新训练模型
for epoch in range(epochs):running_loss = 0for images, labels in trainloader:optimizer.zero_grad()output = model(images)loss = criterion(output, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss/len(trainloader)}")

结论

Adam优化器因其良好的自适应性和快速的收敛能力,成为深度学习中最流行的优化算法之一。在TensorFlow和PyTorch等深度学习框架中,Adam均被用户广泛应用。本文详细介绍了在这两种框架中使用Adam优化器进行模型训练的完整流程,并展示了如何在训练过程中灵活调整超参数。希望这篇文章能帮助你更好地理解和应用Adam优化器。尽管TensorFlow和PyTorch有其独特之处,但选用合适的优化器对于模型的最终表现仍然至关重要。在实际应用中,建议尝试多种优化算法并进行超参数调整,以获得最佳的训练效果。

如果想了解更深入的Adam算法工作原理或其他优化算法的使用,请关注后续更新,继续学习更多的深度学习内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/144039.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯嵌入式客观题合集

十四届模拟赛二客观题 解析:STM32微控制器的I/O端口寄存器必须按32位字被访问 解析:微分电路能将三角波转换为方波;积分电路能将方波转换为三角波 解析:放大电路的本质是能量的控制与转换 解析:具有n个节点&#xff0c…

Ansible——Playbook基本功能???

文章目录 一、Ansible Playbook介绍1、Playbook的简单组成1)“play”2)“task”3)“playbook” 2、Playbook与ad-hoc简单对比区别联系 3、YAML文件语法:---以及多个---??使用 include 指令 1. 基本结构2. 数…

Java入门:09.Java中三大特性(封装、继承、多态)02

2 继承 需要两个类才能实现继承的效果。 比如:类A 继承 类B A类 称为 子类 , 衍生类,派生类 B类 称为 父类,基类,超类 继承的作用 子类自动的拥有父类的所有属性和方法 (父类编写,子类不需要…

IDEA开发HelloWorld程序

IDEA管理Java程序的结构 project(项目、工程)---project中可以创建多个modulemodule(模块)---module中可以创建多个packagepackage(包)---package中可以创建多个classclass(类)---c…

光控资本:股市黑色星期一是什么意思?黑色星期五什么意思?

股市黑色星期一是指股市大跌经常出现在星期一的现象。 最著名的黑色星期一便是1987年10月19日(星期一)产生的全球股市暴降工作,当日全球股市在纽约道琼斯公司工业均匀指数带头暴降下全面下泻, 引发金融商场惊惧, 以及…

python 爬虫 selenium 笔记

todo 阅读并熟悉 Xpath, 这个与 Selenium 密切相关、 selenium selenium 加入无图模式,速度快很多。 from selenium import webdriver from selenium.webdriver.chrome.options import Options# selenium 无图模式,速度快很多。 option Options() o…

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…

【C语言二级考试】循环结构设计

C语言二级考试——循环结构程序设计 五.循环结构程序设计 1.for循环结构 2.while和do-while循环结构 3.continue语句和break语句 4.循环的嵌套 知识点参考【C语言】循环-CSDN博客 文章目录 1.for循环2.while和do-while循环结构3.continue语句和break语句4.循环的嵌套 1.for循环…

阿里云容器服务Kubernetes部署新服务

这里部署的是前端项目 1.登录控制台-选择集群 2.选择无状态-命名空间-使用镜像创建 3.填写相关信息 应用基本信息: 容器配置: 高级配置: 创建成功后就可以通过30006端口访问项目了

【测向定位】差频MUSIC算法DOA估计【附MATLAB代码】

​微信公众号:EW Frontier QQ交流群:554073254 摘要 利用多频处理方法,在不产生空间混叠的情况下,估计出高频区域平面波的波达方向。该方法利用了差频(DF),即两个高频之间的差。这使得能够在可…

视觉语言大模型模型介绍-CLIP学习

多模态学习领域通过结合图像和文本信息,为各种视觉语言任务提供了强大的支持。图像和文本的结合在人工智能领域具有重要的意义,它使得机器能够更全面地理解人类的交流方式。通过这种结合,模型能够处理包括图像描述、视觉问答、特征提取和图像…

多线程---线程的状态及常用方法

1. 线程的状态 在Java程序中,一个线程对象通过调用start()方法启动线程,并且在线程获取CPU时,自动执行run()方法。run()方法执行完毕,代表线程的生命周期结束。 在整个线程的生命周期中,线程的状态有以下六种&#xff…

前海桂湾的海边免费停车场

​前海很多打工人晚上加班前海边散步的地方。相信很多前海打工人都曾经路过这个免费的停车场。坐标出于滨海大道的断头路区域。 看卫星地图可以发现,是个断头路,但是面积还是很大,停个几十辆车没问题。我就停过一次,周末带娃来这里…

ESP8266+使用串口1打印LOG+释放串口0

Menuconfig配置 具体的位置位于Component config > Common ESP-related 配置后,串口0上电还是会打印一些信息,除此之外就不打印了。 ets Jan 8 2013,rst cause:2, boot mode:(3,6)load 0x40100000, len 7792, room 16 tail 0 chksum 0x44 load 0…

Lab2 【哈工大_操作系统】操作系统的引导

本节将更新哈工大《操作系统》课程第二个 Lab 实验 操作系统的引导。按照实验书要求,介绍了非常详细的实验操作流程,并提供了超级无敌详细的代码注释。文末附完整 bootsect.s 和 setup.s 标准答案代码以及超详细注释。 实验目的: 熟悉 hit-o…

C语言中的assert断言

Assert断言 断言是程序中处理异常的一种高级形式。可以在任何时候启用和禁用断言验证,因此可以在测试时启用断言,而在部署时禁用断言。同样,程序投入运行后,最终用户在遇到问题时可以重新启用断言。 用法: #…

AD域控服务器

1.AD域控服务器安装 2.客户端Windows10加入域环境 3.组织单位OU和域用户创建 目的是分部门管理用户和使用域用户登录客户端 4.域用户安全策略 5.当客户端密码锁住了,管理员解锁账户。 6.只允许域用户使用自己的电脑 7.域策略 7.1统一客户端桌面壁纸 7.2重定向用户配置文件路径…

软件设计画图,流程图、甘特图、时间轴图、系统架构图、网络拓扑图、E-R图、思维导图

目录 一、流程图 二、甘特图 三、时间轴图 四、系统架构图 五、网络拓扑图 六、E-R图 七、思维导图 一、流程图 是一种用符号表示算法、工作流或流程的图形。用不同的图形表示不同含义,如椭圆表示开始和结束、菱形表示判断等。 画图工具WPS office 应用市场…

如何使用ssm实现基于vue.js的购物商场的设计与实现+vue

TOC ssm616基于vue.js的购物商场的设计与实现vue 第1章 绪论 1.1选题动因 当前的网络技术,软件技术等都具备成熟的理论基础,市场上也出现各种技术开发的软件,这些软件都被用于各个领域,包括生活和工作的领域。随着电脑和笔记本…

如何使用ssm实现基于ssm框架的车辆出租管理系统+vue

TOC ssm643基于ssm框架的车辆出租管理系统vue 第1章 绪论 1.1 课题背景 二十一世纪互联网的出现,改变了几千年以来人们的生活,不仅仅是生活物资的丰富,还有精神层次的丰富。在互联网诞生之前,地域位置往往是人们思想上不可跨域…