PyTorch 实现手写数字识别

PyTorch 实现手写数字识别

在本教程中,我们将使用 PyTorch 实现经典的手写数字识别任务。我们将使用 MNIST 数据集,这是一个包含手写数字的图像数据集。我们将介绍如何使用 PyTorch 构建、训练和评估一个简单的卷积神经网络(CNN)模型来进行手写数字识别。

1. 项目概述

手写数字识别任务是通过训练模型,让其能够识别手写数字图像并输出正确的数字类别(0-9)。MNIST 数据集包含 28x28 像素的灰度图像,每个图像代表一个手写数字。

我们将使用以下步骤:

  1. 加载 MNIST 数据集
  2. 构建一个卷积神经网络(CNN)
  3. 训练模型
  4. 评估模型性能
  5. 进行测试预测

2. 官方文档链接

  • PyTorch 官方文档
  • MNIST 数据集链接

3. 安装 PyTorch 和依赖库

首先,确保您已经安装了 PyTorch 和相关依赖库。如果没有安装,可以运行以下命令:

pip install torch torchvision matplotlib

4. 加载 MNIST 数据集

我们将使用 torchvision 提供的 MNIST 数据集。它包含 60,000 个训练样本和 10,000 个测试样本。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载 MNIST 训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 加载数据集
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 查看数据集的大小
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")# 可视化部分样本
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0], cmap='gray')plt.title(f"Label: {example_targets[i]}")plt.axis('off')
plt.show()

说明

  • transforms.Compose:我们将图像转换为 PyTorch 张量,并将像素值标准化为 [-1, 1] 的范围。
  • DataLoader:用于将数据集加载为批次,并打乱数据顺序以便训练时使用。

5. 构建卷积神经网络(CNN)

我们将构建一个简单的 CNN 模型,用于手写数字识别。该模型将包含两个卷积层和两个全连接层。

import torch.nn as nn
import torch.nn.functional as F# 定义 CNN 模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入通道为1(灰度图),输出通道为16,卷积核大小为3x3self.conv1 = nn.Conv2d(1, 16, kernel_size=3)# 卷积层2: 输入通道为16,输出通道为32,卷积核大小为3x3self.conv2 = nn.Conv2d(16, 32, kernel_size=3)# 全连接层1: 输入为32*5*5(展平后的特征图),输出为128self.fc1 = nn.Linear(32 * 5 * 5, 128)# 全连接层2: 输入为128,输出为10(10个类别)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积层 + ReLU + 最大池化层x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))# 展平成一维向量x = x.view(-1, 32 * 5 * 5)# 全连接层 + ReLUx = F.relu(self.fc1(x))# 输出层x = self.fc2(x)return x# 实例化模型
model = CNN()
print(model)

说明

  • conv1conv2:卷积层用于提取图像特征。第一个卷积层从 1 个输入通道(灰度图像)转换为 16 个特征图,第二个卷积层将 16 个特征图转换为 32 个特征图。
  • max_pool2d:最大池化层,用于下采样特征图,将特征图尺寸减半。
  • fc1fc2:全连接层,用于将卷积层提取到的特征进行分类。

6. 训练模型

我们将定义损失函数和优化器,然后在训练数据集上训练模型。

import torch.optim as optim# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 将模型移动到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练模型
epochs = 5
for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")print("训练完成!")

说明

  • CrossEntropyLoss:用于分类任务的损失函数,适用于多分类问题。
  • optimizer:使用 Adam 优化器,能够自动调整学习率并加快收敛速度。
  • 训练过程包括前向传播、损失计算、反向传播和参数更新。

7. 评估模型性能

在训练完成后,我们将使用测试数据集来评估模型的性能,计算模型在测试集上的准确率。

# 测试模型
model.eval()  # 切换到评估模式
correct = 0
total = 0with torch.no_grad():  # 关闭梯度计算for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集上的准确率: {100 * correct / total:.2f}%')

说明

  • model.eval():在评估模型时关闭 dropout 和 batch normalization。
  • torch.no_grad():关闭梯度计算以提高测试阶段的效率。

8. 进行预测

最后,我们可以使用训练好的模型对手写数字图像进行预测。

# 从测试集中取出一个样本
example_data, example_target = next(iter(test_loader))
example_data = example_data.to(device)# 使用模型进行预测
model.eval()
with torch.no_grad():output = model(example_data)# 可视化预测结果
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0].cpu(), cmap='gray')plt.title(f"预测: {torch.argmax(output[i]).item()}")plt.axis('off')
plt.show()

说明

  • 取出测试集中的一批样本进行预测,并可视化模型的预测结果。

9. 总结

在本教程中,我们使用 PyTorch 实现了手写数字识别任务,构建了一个简单的卷积神经网络(CNN),并在 MNIST 数据集上进行了训练和评估。通过此项目,您可以了解如何加载数据、构建模型、训练、评估和测试 PyTorch 模型。

10. 改进方向

  • 增加网络深度:可以增加卷积层和全连接层的

数量,提高模型的表现。

  • 使用数据增强:通过数据增强技术(旋转、缩放等),可以提高模型的泛化能力。
  • 应用在其他数据集:除了 MNIST,还可以将模型应用到其他数据集,如 FashionMNIST、CIFAR-10 等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146130.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

项目第四弹:交换机、队列、绑定信息管理模块分析与代码实现

项目第四弹:交换机、队列、绑定信息管理模块分析与代码实现 一、模块设计分析1.模块划分2.功能需求 二、交换机模块的实现1.交换机结构体的实现2.交换机持久化管理模块的实现3.交换机对外管理模块实现声明、删除交换机时的查找不能复用exists函数为何持久化管理模块…

查找算法 01分块查找

自己设计一个分块查找的例子,不少于15个数据元素,并建立分块查找的索引 基于上述例子,计算查找成功的ASL、查找失败的ASL 拓展: ‌‌分块查找的平均查找长度(‌ASL)的计算公式如下‌:‌ ‌顺序…

ESP32 JTAG 调试

前言 个人邮箱:zhangyixu02gmail.com本人使用的是 Ubuntu 环境,采用 GDB 方式进行调试。对于新手,我个人还是建议参考ESP32S3学习笔记(0)—— Vscode IDF环境搭建及OpenOCD调试介绍进行图形化的方式调试。如果是希望在…

占领矩阵-第15届蓝桥省赛Scratch中级组真题第5题

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第190讲。 如果想持续关注Scratch蓝桥真题解读,可以点击《Scratch蓝桥杯历年真题》并订阅合集,…

Python酷库之旅-第三方库Pandas(122)

目录 一、用法精讲 541、pandas.DataFrame.take方法 541-1、语法 541-2、参数 541-3、功能 541-4、返回值 541-5、说明 541-6、用法 541-6-1、数据准备 541-6-2、代码示例 541-6-3、结果输出 542、pandas.DataFrame.truncate方法 542-1、语法 542-2、参数 542-3…

植保无人机是朝阳产业还是夕阳产业?

植保无人机产业是朝阳产业还是夕阳产业,可以从多个维度进行分析: 一、市场需求与增长趋势 市场需求:随着农业现代化的推进和劳动力成本的上升,植保无人机因其高效、安全、节省农药等优势,在农业生产中的应用越来越广…

自闭症能上寄宿学校吗?了解解答与选择

在探讨自闭症儿童教育的话题时,寄宿学校作为一种特殊的教育模式,常常引发家长们的关注与讨论。对于自闭症儿童而言,寄宿学校既是一个充满挑战的新环境,也是一个能够促进他们独立成长与社交融合的重要平台。今天,我们将…

自制数据库空洞率清理工具-C版-03-EasyClean-V1.3(支持南大通用数据库Gbase8a)

目录 一、环境信息 二、简述 三、升级点 四、支持功能 五、空洞率 六、工具流程图 1、流程描述 2、注意点 (1)方法一 (2)方法二 七、清理空洞率流程图 八、安装包下载地址 九、参数介绍 1、命令模板 2、命令样例 3…

【C语言-数据结构】单链表的定义

单链表的定义(实现) 比较顺序表和单链表的物理存储结构就能够清楚地发现二者的区别 用代码定义一个单链表 typedef struct LNode{ElemType data; //每个结点存放一个数据元素struct LNode* next; //指针指向下一个结点 }LNode, *LinkList;//要表示一个…

[JavaEE] TCP协议

目录 一、TCP协议段格式 二、TCP确保传输可靠的机制 2.1 确认应答 2.2 超时重传 2.3 连接管理 2.3.1 三次握手 2.3.2 四次挥手 2.4 滑动窗口 2.4.1 基础知识 2.4.2 两种丢包情况 2.4.2.1 数据报已经抵达,ACK丢包 2.4.2.2 数据包丢包 2.5 流量控制…

【时时三省】(C语言基础)指针笔试题2

山不在高,有仙则名。水不在深,有龙则灵。 ----CSDN 时时三省 笔试题2 这里的0x1是16进制的1 跟十进制的1一样 这道题考察的是:指针类型决定了指针的运算 p是上面结构体的指针 它指向的大小结果是20个字节 指针…

项目第五弹:队列消息管理模块

项目第五弹:队列消息管理模块 一、消息如何组织并管理1.消息结构体2.消息持久化管理模块设计1.数据消息文件名2.临时消息文件名3.对外接口与包含成员 二、自定义应用层协议解决文件读写的粘包问题1.Length-Value协议 三、队列消息管理模块设计1.待确认消息哈希表2.待…

[数据结构]动态顺序表的实现与应用

文章目录 一、引言二、动态顺序表的基本概念三、动态顺序表的实现1、结构体定义2、初始化3、销毁4、扩容5、缩容5、打印6、增删查改 四、分析动态顺序表1、存储方式2、优点3、缺点 五、总结1、练习题2、源代码 一、引言 想象一下,你有一个箱子(静态顺序…

【医学半监督】对比互补掩蔽的自监督预训练半监督心脏图像分割

SELF-SUPERVISED PRE-TRAINING BASED ON CONTRASTIVE COMPLEMENTARY MASKING FOR SEMI-SUPERVISED CARDIAC IMAGE SEGMENTATION 2024 IEEE International Symposium on Biomedical Imaging (ISBI) 摘要: 心脏结构分割对心脏病诊断非常重要,而使用大量注释的深度学习在这项任…

Buck变换器闭环控制,simulink仿真模型(适合初学者学习)

Buck变换器,又称为降压斩波器,是一种常见的DC-DC转换器,广泛应用于电源管理领域。它通过开关元件(通常是MOSFET或BJT)的导通与截止,改变输入电压到负载的平均电压,从而实现电压的降低。在实际应…

harbor私有镜像仓库,搭建及管理

私有镜像仓库 docker-distribution docker的镜像仓库,默认端口号5000 做个仓库,把镜像放里头,用什么服务,起什么容器 vmware公司在docker私有仓库的基础上做了一个web页面,是harbor docker可以把仓库的镜像下载到本地&…

tauri嵌入自定义目录/文件,并在代码中读取文件内容的操作流程

可以看官方文档:Embedding Additional Files | Tauri Apps 在绑定了文件之后,可以在js中访问嵌入的文件或者在rust中读取嵌入的文件内容,详细的配置操作如下。 在src-tauri中创建自定义文件夹或文件,并在在tauri.conf.json中配置…

Java多线程Thread及其原理深度解析

文章目录 1. 实现多线程的方式2. Thread 部分源码2.1. native 方法注册2.2. Thread 中的成员变量2.3. Thread 构造方法与初始化2.4. Thread 线程状态与操作系统状态2.4. start() 与 run() 方法2.5. sleep() 方法2.6. join() 方法2.7. interrupt() 方法 本文参考: 线…

Spring自定义参数解析器

在这篇文章中,我们认识了参数解析器和消息转换器,今天我们来自定义一个参数解析器。 自定义参数解析器 实现HandlerMethodArgumentResolver的类,并注册到Spring容器。 Component//注册到Spring public class UserAr…

Java集合必知必会:热门面试题汇编与核心源码(ArrayList、HashMap)剖析

写在前面 🔥我把后端Java面试题做了一个汇总,有兴趣大家可以看看!这里👉 ⭐️在无数次的复习巩固中,我逐渐意识到一个问题:面对同样的面试题目,不同的资料来源往往给出了五花八门的解释&#…