LeNet网络复现

文章目录

  • 1. LeNet历史背景
    • 1.1 早期神经网络的挑战
    • 1.2 LeNet的诞生背景
  • 2. LeNet详细结构
    • 2.1 总览
    • 2.2 卷积层与其特点
    • 2.3 子采样层(池化层)
    • 2.4 全连接层
    • 2.5 输出层及激活函数
  • 3. LeNet实战复现
    • 3.1 模型搭建model.py
    • 3.2 训练模型train.py
    • 3.3 测试模型test.py
  • 4. LeNet的变种与实际应用
    • 4.1 LeNet-5及其优化
    • 4.2 从LeNet到现代卷积神经网络

1. LeNet历史背景

1.1 早期神经网络的挑战

早期神经网络面临了许多挑战。首先,它们经常遇到训练难题,例如梯度消失和梯度爆炸,特别是在使用传统激活函数如sigmoid或tanh时。另外,当时缺乏大规模、公开的数据集,导致模型容易过拟合并且泛化性能差。再者,受限于当时的计算资源,模型的大小和训练速度受到了很大的限制。最后,由于深度学习领域还处于萌芽阶段,缺少许多现代技术来优化和提升模型的表现。

1.2 LeNet的诞生背景

LeNet的诞生背景是为了满足20世纪90年代对手写数字识别的实际需求,特别是在邮政和银行系统中。Yann LeCun及其团队意识到,对于图像这种有结构的数据,传统的全连接网络并不是最佳选择。因此,他们引入了卷积的概念,设计出了更适合图像处理任务的网络结构,即LeNet。

2. LeNet详细结构

2.1 总览

在这里插入图片描述

2.2 卷积层与其特点

卷积层是卷积神经网络(CNN)的核心。这一层的主要目的是通过卷积操作检测图像中的局部特征。

特点:

  1. 局部连接性: 卷积层的每个神经元不再与前一层的所有神经元相连接,而只与其局部区域相连接。这使得网络能够专注于图像的小部分,并检测其中的特征。
  2. 权值共享: 在卷积层中,一组权值在整个输入图像上共享。这不仅减少了模型的参数,而且使得模型具有平移不变性。
  3. 多个滤波器: 通常会使用多个卷积核(滤波器),以便在不同的位置检测不同的特征。

2.3 子采样层(池化层)

池化层是卷积神经网络中的另一个关键组件,用于缩减数据的空间尺寸,从而减少计算量和参数数量。

主要类型:

  1. 最大池化(Max pooling): 选择覆盖区域中的最大值作为输出。
  2. 平均池化(Average pooling): 计算覆盖区域的平均值作为输出。

2.4 全连接层

在卷积神经网络的最后,经过若干卷积和池化操作后,全连接层用于将提取的特征进行“拼接”,并输出到最终的分类器。

特点:

  1. 完全连接: 全连接层中的每个神经元都与前一层的所有神经元相连接。
  2. 参数量大: 由于全连接性,此层通常包含网络中的大部分参数。
  3. 连接多个卷积或池化层的特征: 它的主要目的是整合先前层中提取的所有特征。

2.5 输出层及激活函数

输出层:
输出层是神经网络的最后一层,用于输出预测结果。输出的数量和类型取决于特定任务,例如,对于10类分类任务,输出层可能有10个神经元。

激活函数:
激活函数为神经网络提供了非线性,使其能够学习并进行复杂的预测。

  1. Sigmoid: 取值范围为(0, 1)。
  2. Tanh: 取值范围为(-1, 1)。
  3. ReLU (Rectified Linear Unit): 最常用的激活函数,将所有负值置为0。
  4. Softmax: 常用于多类分类的输出层,它返回每个类的概率。

3. LeNet实战复现

3.1 模型搭建model.py

import torch
from torch import nn# 自定义网络模型
class LeNet(nn.Module):# 1. 初始化网络(定义初始化函数)def __init__(self):super(LeNet, self).__init__()# 定义网络层self.Sigmoid = nn.Sigmoid()self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)# 展开self.flatten = nn.Flatten()self.f6 = nn.Linear(120, 84)self.output = nn.Linear(84, 10)# 2. 前向传播网络def forward(self, x):x = self.Sigmoid(self.c1(x))x = self.s2(x)x = self.Sigmoid(self.c3(x))x = self.s4(x)x = self.c5(x)x = self.flatten(x)x = self.f6(x)x = self.output(x)return xif __name__ == "__main__":x = torch.rand([1, 1, 28, 28])model = LeNet()y = model(x)

3.2 训练模型train.py

import torch
from torch import nn
from model import LeNet
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 数据转换为tensor格式
data_transformer = transforms.Compose([transforms.ToTensor()
])# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)# 定义一个损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()# 定义一个优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 学习率每隔10轮, 变换原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0for batch, (x, y) in enumerate(dataloader):# 前向传播x, y = x.to(device), y.to(device)output = model(x)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred)/output.shape[0]optimizer.zero_grad()cur_loss.backward()optimizer.step()loss += cur_loss.item()current += cur_acc.item()n = n + 1print("train_loss" + str(loss/n))print("train_acc" + str(current/n))# 定义测试函数
def val(dataloader, model, loss_fn):model.eval()loss, current, n = 0.0, 0.0, 0with torch.no_grad():for batch, (x, y) in enumerate(dataloader):# 前向传播x, y = x.to(device), y.to(device)output = model(x)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print("val_loss" + str(loss / n))print("val_acc" + str(current / n))return current/n# 开始训练
epoch = 50
min_acc = 0for t in range(epoch):print(f'epoch{t+1}\n-----------------------------------------------------------')train(train_dataloader, model, loss_fn, optimizer)a = val(test_dataloader, model, loss_fn)# 保存最好的模型权重if a > min_acc:folder = "sava_model"if not os.path.exists(folder):os.mkdir("sava_model")min_acc = aprint("sava best model")torch.save(model.state_dict(), 'sava_model/best_model.pth')
print('Done!')

3.3 测试模型test.py

import torch
from model import LeNet
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage# 数据转换为tensor格式
data_transformer = transforms.Compose([transforms.ToTensor()
])# 加载训练的数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transformer, download=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加载测试的数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transformer, download=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"# 调用搭好的模型,将模型数据转到GPU上
model = LeNet().to(device)# 加载模型权重并设置为评估模式
model.load_state_dict(torch.load("./sava_model/best_model.pth"))
model.eval()# 获取结果
classes = ["0","1","2","3","4","5","6","7","8","9",
]# 把tensor转换为图片, 方便可视化
show = ToPILImage()# 进入验证
for i in range(5):x, y = test_dataset[i]show(x).show()x = torch.unsqueeze(x, dim=0).float().to(device)with torch.no_grad():pred = model(x)predicted, actual = classes[torch.argmax(pred[0])], classes[y]print(f'Predicted: {predicted}, Actual: {actual}')

4. LeNet的变种与实际应用

4.1 LeNet-5及其优化

LeNet-5是由Yann LeCun于1998年设计的,并被广泛应用于手写数字识别任务。它是卷积神经网络的早期设计之一,主要包含卷积层、池化层和全连接层。

结构:

  1. 输入层:接收32×32的图像。
  2. 卷积层C1:使用5×5的滤波器,输出6个特征图。
  3. 池化层S2:2x2的平均池化。
  4. 卷积层C3:使用5×5的滤波器,输出16个特征图。
  5. 池化层S4:2x2的平均池化。
  6. 卷积层C5:使用5×5的滤波器,输出120个特征图。
  7. 全连接层F6。
  8. 输出层:10个单元,对应于0-9的手写数字。
    激活函数:Sigmoid或Tanh。

优化:

  1. ReLU激活函数: 原始的LeNet-5使用Sigmoid或Tanh作为激活函数,但现代网络更喜欢使用ReLU,因为它的训练更快,且更少受梯度消失的影响。
  2. 更高效的优化算法: 如Adam或RMSProp,它们通常比传统的SGD更快、更稳定。
  3. 批量归一化: 加速训练,并提高模型的泛化能力。
  4. Dropout: 在全连接层中引入Dropout可以增强模型的正则化效果。

4.2 从LeNet到现代卷积神经网络

从LeNet-5开始,卷积神经网络已经经历了巨大的发展。以下是一些重要的里程碑:

  1. AlexNet (2012): 在ImageNet竞赛中取得突破性的成功。它具有更深的层次,使用ReLU激活函数,以及Dropout来防止过拟合。
  2. VGG (2014): 由于其统一的结构(仅使用3×3的卷积和2x2的池化)而闻名,拥有多达19层的版本。
  3. GoogLeNet/Inception (2014): 引入了Inception模块,可以并行执行多种大小的卷积。
  4. ResNet (2015): 引入了残差块,使得训练非常深的网络变得可能。通过这种方式,网络可以达到上百甚至上千的层数。
  5. DenseNet (2017): 每层都与之前的所有层连接,导致具有非常稠密的特征图。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146441.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

借助Log360实现综合可见性的增强网络安全

当今的企业对技术的依赖程度前所未有,因此强大的威胁检测和响应策略变得至关重要。在现代世界中,网络犯罪分子不断寻找新的、富有创意的方式,以侵入组织的网络并窃取敏感数据。综合性的可见性是一个关键要素,有时被忽视&#xff0…

Unity HDRP Custom Pass 实现场景雪地效果

先使用Shader Graph连一个使用模型法线添加雪地的shader,并赋给一个material。 1.1 先拿到模型世界坐标下的顶点法线,简单处理一下,赋给透明度即可。 给场景添加Custom Pass,剔除不需要的层级。 1.在Hierarchy界面中&#xff…

Wi-Fi直连分享:Android设备间的高速连接

Wi-Fi直连分享:Android设备间的高速连接 引言 随着无线局域网(Wi-Fi)的普及和发展,使用Wi-Fi直连技术(P2P)在没有中间接入点的情况下实现设备间直接互联成为可能。通过Wi-Fi直连,具备相应硬件…

【React】React组件生命周期以及触发顺序(部分与vue做比较)

最近在学习React,发现其中的生命周期跟Vue有一些共同点,但也有比较明显的区别,并且执行顺序也值得讨论一下,于是总结了一些资料在这里,作为学习记录。 v17.0.1后生命周期图片 初始化阶段 由ReactDOM.render()触发 —…

【Java 进阶篇】JDBC插入数据详解

在Java应用程序中,与数据库交互是一项常见的任务。其中,插入数据操作是一种基本的数据库操作之一。本文将详细介绍如何使用Java JDBC(Java Database Connectivity)来执行插入数据操作。无论您是初学者还是有一定经验的开发人员&am…

CSS详细基础(六)边框样式

本期是CSS基础的最后一篇~ 目录 一.border属性 二.边框属性复合写法 三.CSS修改表格标签 四.内边距属性 五.外边距属性 六.其他杂例 1.盒子元素水平居中 2.清除网页内外元素边距 3.外边距的合并与塌陷 4.padding不会撑大盒子的情况 七.综合案例——新浪导航栏仿真 …

Android studio升级Giraffe | 2022.3.1 Patch 1踩坑

这里写自定义目录标题 not "opens java.io" to unnamed module错误报错信息解决 superclass access check failed: class butterknife.compiler.ButterKnifeProcessor$RScanner报错报错信息解决 Android studio升级Giraffe | 2022.3.1 Patch 1后,出现项目…

imgui开发笔记<1>、ubuntu环境下快速应用

去这个链接下载imgui源码(在此之前需要安装opengl glfw3等等): sudo apt-get install libglfw3-dev https://github.com/ocornut/imgui 我这里源码下载到/home/temp/imgui目录下,咱们不需要编译源码成库,而是直接将下…

【Axure高保真原型】3D圆柱图_中继器版

今天和大家分享3D圆柱图_中继器版的原型模板,图表在中继器表格里填写具体的数据,调整坐标系后,就可以根据表格数据自动生成对应高度的圆柱图,鼠标移入时,可以查看对应圆柱体的数据……具体效果可以打开下方原型地址体验…

网络安全渗透测试工具之skipfish

网络安全渗透测试工具skipfish介绍 在数字化的时代,Web 应用程序安全成为了首要任务。想象一下,您是一位勇敢的安全冒险家,迎接着那些隐藏在 Web 应用程序中的未知风险。而在这个冒险之旅中,您需要一款强大的工具来帮助您发现漏洞,揭示弱点。而这个工具就是 Skipfish。 …

LeetCode 周赛上分之旅 #48 一道简单的树上动态规划问题

⭐️ 本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 和 BaguTree Pro 知识星球提问。 学习数据结构与算法的关键在于掌握问题背后的算法思维框架,你的思考越抽象,它能覆盖的问题域就越广,理解难度…

点击、拖拉拽,BI系统让业务掌握数据分析主动权

在今天的商业环境中,数据分析已经成为企业获取竞争优势的关键因素之一。然而,许多企业在面对复杂的数据分析工具时,却常常感到困扰。这些工具往往需要专业的技术人员操作,而且界面复杂,难以理解和使用。对业务人员来说…

阿里云 Oss 权限控制

前言 最近公司的私有 Oss 服务满了,且 Oss 地址需要设置权限,只有当前系统的登录用户才能访问 Oss 下载地址。一开始想着用 Nginx 做个转发来着,Nginx 每当检测当前请求包含特定的 Oss 地址就转发到我们的统一鉴权接口上去,但是紧…

picoctf_2018_shellcode

picoctf_2018_shellcode Arch: i386-32-little RELRO: Partial RELRO Stack: No canary found NX: NX disabled PIE: No PIE (0x8048000) RWX: Has RWX segments32位,啥都没开 这个看着挺大的,直接来个ROPchain,…

Redis缓存穿透、击穿和雪崩

面试高频 服务的高可用问题! 在这里我们不会详细的区分析解决方案的底层! Redis缓存概念 Redis缓存的使用,极大的提升了应用程序的性能和效率,特别是数据查询方面。但同时,它也带来了一些问题。其中,最要…

Android stdio的Gradle菜单栏无内容问题的解决方法

右边Gradle菜单栏里没有Tasks选项内容的问题 正常情况↓ 如果这个问题如果无法解决的话,Gradle打包就只能通过控制台输入命令来解决,但这无疑是把简单问题复杂化了,我们来看看怎么解决这个问题吧。 这里有几个方法提供,可以自行选…

排序篇(四)----归并排序

排序篇(四)----归并排序 1.归并(递归) 基本思想: 归并排序(MERGE-SORT)是建立在归并操作上的一种有效的排序算法,该算法是采用分治法(Divide andConquer)的一个非常典型的应用。将已有序的子序列合并,得到…

leetCode 376.摆动序列 动态规划 + 图解 + 状态转移

376. 摆动序列 - 力扣(LeetCode) 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元素的序列也视作摆动序列。 例如…

微信小程序引入字体在部分机型失效不兼容解决办法

写小程序页面,美工作图用了特殊字体 引入代码: font-face {font-family: huxiaobo;src: url("https://xxxxxxxx.top/assets/fonts/huxiaobonanshenti.woff") } .font-loaded {font-family: "huxiaobo"; } 上线后发现部分安卓机型不…