VGG16模型实现MNIST图像分类

MNIST图像数据集

MNIST(Modified National Institute of Standards and Technology)是一个经典的机器学习数据集,常用于训练和测试图像处理和机器学习算法,特别是在数字识别领域。该数据集包含了大约 7 万张手写数字图片,其中 6 万张是用于训练,1 万张用于测试。每张图片都是 28x28 像素的灰度图像,展示了从 0 到 9 的手写数字。这些图像已经被处理过,以使得数字在图像中居中且尺寸一致。

MNIST 数据集是一个广泛被用于测试新的机器学习算法的基准,因为它相对较小,易于理解,且可以用于快速验证算法的有效性。许多人使用 MNIST 作为开始学习深度学习的入门数据集,因为它提供了一个简单但具有挑战性的任务,即将手写数字图像分类为相应的数字。

尽管 MNIST 已经存在了很长时间,但它仍然是一个重要的基准数据集,特别是对于新的机器学习研究和算法的初步测试。MINIST数据集中部分图片如下所示:

下载MNIST数据集

由于MINIST作为经典数据集,已经被内嵌在torchvision库中的dataset中了,所以直接使用代码datasets.MNIST进行下载即可。

下载后的文件格式如下图所示。

搭建VGG16图像分类模型

class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数

定义VGG网络结构如上所示,在上面代码中我定义了一个基于 VGG16 架构分类器的模型。VGG16 是一种经典的卷积神经网络模型,由 16 层深度的卷积层和全连接层组成,所构建的 VGGClassifier 类的网络结构包含两个主要部分:

特征提取器(features):这部分使用了预训练的 VGG16 模型的特征提取器。通过调用 models.vgg16(pretrained=True).features 来加载 VGG16 的特征提取器部分。然后,将第一层卷积层的输入通道数从 3 修改为 1,以适应 MNIST 数据集的灰度图像格式。

分类器(classifier):这部分是自定义的分类器,用于对提取的特征进行分类。首先,通过几个全连接层将特征图展平成一维张量,然后通过一系列的线性层和激活函数对特征进行处理。具体来说,包括:一个包含 256 个神经元的全连接层,输入维度为 512x7x7(经过 VGG16 的特征提取器后的输出尺寸),使用 ReLU 激活函数。一个 Dropout 层,用于防止过拟合,随机关闭一些神经元。一个包含 256 个神经元的全连接层,使用 ReLU 激活函数。再次添加一个 Dropout 层。最后是一个包含 num_classes 个神经元的全连接层,用于输出最终的类别预测结果。

通过上述方式,整个网络结构将 VGG16 的特征提取器和自定义的分类器相结合,以适应 MNIST 数据集的图像分类任务。

构建的VGG网络结构如下图所示:

VGG网络结构图

模型训练

# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

模型参数设置如下表所示(代码见上)

模型超参数

数值

batchsize

16

num_epochs

5

learning_rate

0.001

num_classes

10

由于MINIST数据集样本数量较大,所以对于上述代码训练速度也会较慢,我考虑使用我的笔记本电脑独显进行运算,却发现电脑显存不够,于是我调小batchsize与epoch,并降低学习率learning rate才让GPU勉强能够运行上面代码,并获得到了模型model.pth,最终获得模型在测试集上面的识别精度为96.7%,精度还是比较高的。(由于笔记本电脑性能有限,在处理较大规模数据的小型项目时速度较慢,故上述代码运行了一下午左右的时间才跑完)。

模型测试

使用上面模型进行手写数字识别的检验。绘制一张图片上面含有9张子图,随机选取识别结果的9张进行展示 。识别效果以及运行结果如下图所示。

 

附录:

 VGG训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transformsimport warnings
warnings.filterwarnings("ignore")# 定义数据预处理操作
transform = transforms.Compose([transforms.Resize(224), # 将图像大小调整为(224, 224)transforms.ToTensor(),  # 将图像转换为PyTorch张量transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数def forward(self, x):x = self.features(x)  # 通过特征提取器提取特征x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量x = self.classifier(x)  # 通过分类器进行分类预测return xdef _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)# 训练模型
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空梯度缓存loss.backward()  # 计算梯度optimizer.step()  # 更新权重参数if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),loss.item()))# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间for images, labels in test_loader:images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备outputs = model(images)  # 模型前向传播,得到预测结果_, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别total += labels.size(0)  # 更新总样本数量correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1556960.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

喜讯 | 攸信技术入选第六批专精特新“小巨人”企业

日前,根据工信部评审结果,厦门市工业和信息化局公示了第六批专精特新“小巨人”企业和第三批专精特新“小巨人”复核通过企业名单,其中,厦门攸信信息技术有限公司进入第六批专精特新“小巨人”企业培育。 “专精特新”企业是指具有…

图像分割恢复方法

传统的图像分割方法主要依赖于图像的灰度值、纹理、颜色等特征,通过不同的算法将图像分割成多个区域。这些方法通常可以分为以下几类: 1.基于阈值的方法 2.基于边缘的方法 3.基于区域的方法 4.基于聚类的方法 下面详细介绍这些方法及其示例代码。 1. 基…

代码随想录--栈与队列--用栈实现队列

队列是先进先出,栈是先进后出。 如图所示: 题目 使用栈实现队列的下列操作: push(x) – 将一个元素放入队列的尾部。 pop() – 从队列首部移除元素。 peek() – 返回队列首部的元素。 empty() – 返回队列是否为空。 示例: MyQueue qu…

draw.io 设置默认字体及添加常用字体

需求描述 draw.io 是一个比较好的开源免费画图软件。但是其添加容器或者文本框时默认的字体是 Helvetica,一般的期刊、会议论文或者学位论文要求的英文字体是 Times New Roman,中文字体是 宋体,所以一般需要在文本字体选项里的下拉列表选择 …

分层解耦-05.IOCDI-DI详解

一.依赖注入的注解 在我们的项目中,EmpService的实现类有两个,分别是EmpServiceA和EmpServiceB。这两个实现类都加上Service注解。我们运行程序,就会报错。 这是因为我们依赖注入的注解Autowired默认是按照类型来寻找bean对象的进行依赖注入…

2-115 基于matlab的瞬态提取变换(TET)时频分析

基于matlab的瞬态提取变换(TET)时频分析,瞬态提取变换是一种比较新的TFA方法。该方法的分辨率较高,能够较好地提取出故障的瞬态特征,用于故障诊断领域。通过对原始振动信号设置不同信噪比噪声,对该方法的抗…

关于一个模仿qq通信程序

7月份的时候还在学校那个时候想要学习嵌入式Linux,但是还没有买开发板来玩,再学linux系统编程,网络编程,Linux系统的文件IO,于是学完之后想做一个模仿qq的通信程序于是就有了这个“ailun.exe”,因为暑假去打…

【数据结构与算法】线性表

文章目录 一.什么是线性表?二.线性表如何存储?三.线性表的类型 我们知道从应用中抽象出共性的逻辑结构和基本操作就是抽象数据类型,然后实现其存储结构和基本操作。下面我们依然按这个思路来认识线性表 一.什么是线性表? 定义 线性…

TryHackMe 第7天 | Web Fundamentals (二)

继续介绍一些 Web hacking 相关的漏洞。 IDOR IDOR (Insecure direct object reference),不安全的对象直接引用,这是一种访问控制漏洞。 当 Web 服务器接收到用户提供的输入来检索对象时 (包括文件、数据、文档),如果对用户输入数据过于信…

【springboot】使用代码生成器快速开发

接上一项目&#xff0c;使用mybatis-plus-generator实现简易代码文件生成 在fast-demo-web模块中的pom.xml中添加mybatis-plus-generator、freemarker和Lombok依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator&…

Python | 由高程计算坡度和坡向

写在前面 之前参加一个比赛&#xff0c;提供了中国的高程数据&#xff0c;可以基于该数据进一步计算坡度和坡向进行相关分析。 对于坡度和坡向&#xff0c;这里分享一个找到的库&#xff0c;可以方便快捷的计算。这个库为&#xff1a;RichDEM&#xff0c;官网地址如下 https…

SAP学习笔记 - 豆知识11 - 如何查询某个字段/DataElement/Domain在哪个表里使用?

大家知道SAP的表有10几万个&#xff08;也有说30多万个的&#xff0c;总之很多就是了&#xff09;&#xff0c;而且不断增多&#xff0c;那么当想知道一个字段在哪个表里使用的时候该怎么办呢&#xff1f; 思路就是SAP的表其实也是存在表里的&#xff1a;&#xff09;&#xf…

【Git】TortoiseGitPlink提示输入密码解决方法

问题 克隆仓库&#xff0c;TortoiseGitPlink提示输入密码 解法 1、打开TortoiseGit 下的puttygen工具 位置&#xff1a;C:\Program Files\TortoiseGit\bin\ 2、点击【Load】按钮&#xff0c;载入 C:\Users\Administrator\.ssh\ 文件夹下的id_rsa文件。 3、点击save private …

qt_c++_xml存这种复杂类型

demo&#xff0c;迅雷链接。或者我主页上传的资源 链接&#xff1a;https://pan.xunlei.com/s/VO8bIvYFfhmcrwF-7wmcPW1SA1?pwdnrp4# 复制这段内容后打开手机迅雷App&#xff0c;查看更方便 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow>#include…

请散户股民看过来,密切关注两件大事

明天股市要开市&#xff0c;不仅散户股民期盼节后股市大涨&#xff0c;上面也同样想在节后来上一个“开门红”。 为此&#xff0c;上面没休假&#xff0c;关起门来办了两件大事&#xff0c;这两天发布消息已提前预热了。 两件大事如下&#xff1a; 一是&#xff0c;上交所10…

什么是 JavaScript 的数组空槽

JavaScript 中的数组空槽一直是一个非常有趣且颇具争议的话题。我们可能对它的实际意义、历史以及现今的新版本中对它的处理方式有所疑问。数组空槽的存在最早可以追溯到 JavaScript 的诞生之初&#xff0c;当时的设计决定让它成为了现代 JavaScript 开发中的一种特别的现象。 …

大数据新视界 --大数据大厂之数据血缘追踪与治理:确保数据可追溯性

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

计算机毕业设计hadoop+spark天气预测 天气可视化 天气大数据 空气质量检测 空气质量分析 气象大数据 气象分析 大数据毕业设计 大数据毕设

Hadoop天气预测系统开题报告 一、研究背景与意义 在信息化和大数据时代&#xff0c;天气数据已成为社会生活和经济发展中不可或缺的重要资源。天气预测系统作为现代气象学的重要组成部分&#xff0c;对于农业生产、交通管理、环境保护以及防灾减灾等方面都具有重要意义。然而…

集智书童 | 用于时态动作检测的预测反馈 DETR !

本文来源公众号“集智书童”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;用于时态动作检测的预测反馈 DETR ! 视频中的时间动作检测&#xff08;TAD&#xff09;是现实世界中的一个基本且具有挑战性的任务。得益于 Transformer …

Chrome浏览器调用ActiveX控件--allWebOffice控件

背景 allWebOffice控件能够实现在浏览器窗口中在线操作文档的应用&#xff08;阅读、编辑、保存等&#xff09;&#xff0c;支持编辑文档时保留修改痕迹&#xff0c;支持书签位置内容动态填充&#xff0c;支持公文套红&#xff0c;支持文档保护控制等诸多办公功能&#xff0c;本…