手写数字识别案例分析(torch,深度学习入门)

在人工智能和机器学习的广阔领域中,手写数字识别是一个经典的入门级问题,它不仅能够帮助我们理解深度学习的基本原理,还能作为实践编程和模型训练的良好起点。本文将带您踏上手写数字识别的深度学习之旅,从数据集介绍、模型构建到训练与评估,一步步深入探索。

一、引言

手写数字识别(Handwritten Digit Recognition)是指通过计算机程序自动识别手写数字的过程。最著名的手写数字数据集之一是MNIST(Modified National Institute of Standards and Technology database),它包含了大量的手写数字图片,每张图片都被标记了对应的数字(0-9)。这个数据集成为了初学者学习深度学习,尤其是卷积神经网络(CNN)的首选。

二、MNIST数据集简介

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本都是一张28x28像素的灰度图像,代表了一个手写数字。这些图像已经被归一化并居中在图像中心,使得数字不会受到位置变化的影响。

 PyTorch 和 torchvision 库来下载并准备 MNIST 数据集,包括训练集和测试集

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下载训练数据集(图片+标签)'''
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
  1. 打印设备信息:您的代码已经很好地检查了CUDA和MPS(针对Apple M系列芯片)的可用性,并设置了相应的设备。但是,在打印设备信息时,有一个小错误在字符串格式化上。您需要确保在字符串中正确地包含变量名。

  2. 打印数据形状:您已经正确地设置了DataLoader并打印了测试数据集中的一个批次的数据和标签的形状。这是一个很好的实践,可以帮助您了解数据的维度。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 通常训练时会打乱数据  
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试时不需要打乱数据  # 打印测试数据集的一个批次的数据和标签的形状  
for x, y in test_dataloader:  print(f"Shape of x [N,C,H,W]: {x.shape}")  # 注意这里的x是图像,但MNIST是灰度图,所以C=1  print(f"Shape of y: {y.shape}, {y.dtype}")  # y是标签,通常是一维的,且为long类型  break  # 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU  
device = "cuda" if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else "cpu")  
print(f"Using {device} device")  # 确保在字符串中正确地包含了变量名  

三、训练模型选择

一、创建一个具有多个隐藏层的神经网络,这些层都使用了nn.Linear来定义全连接层,并使用torch.sigmoid作为激活函数。

import torch  
import torch.nn as nn  class NeuralNetwork(nn.Module):  def __init__(self):  super().__init__()  self.flatten = nn.Flatten()  self.hidden1 = nn.Linear(28 * 28, 256)  self.relu1 = nn.ReLU()  self.hidden2 = nn.Linear(256, 128)  self.relu2 = nn.ReLU()  self.hidden3 = nn.Linear(128, 64)  self.relu3 = nn.ReLU()  self.hidden4 = nn.Linear(64, 32)  self.relu4 = nn.ReLU()  self.out = nn.Linear(32, 10)  # 输出层对应于10个类别的得分  def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x model = NeuralNetwork().to(device)  
print(model)  

二、定义了一个具有三个卷积层的CNN,每个卷积层后面都跟着ReLU激活函数,前两个卷积层后面还跟着最大池化层。最后,通过一个全连接层将卷积层的输出转换为10个类别的得分。

import torch  
import torch.nn as nn  class CNN(nn.Module):  def __init__(self):  super(CNN, self).__init__()  self.conv1 = nn.Sequential(  nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  nn.ReLU(),  nn.MaxPool2d(kernel_size=2),  )  self.conv2 = nn.Sequential(  nn.Conv2d(16, 32, 5, 1, 2),  nn.ReLU(),  nn.Conv2d(32, 32, 5, 1, 2),  nn.ReLU(),  nn.MaxPool2d(2),  )  self.conv3 = nn.Sequential(  nn.Conv2d(32, 64, 5, 1, 2),  nn.ReLU(),  )  self.out = nn.Linear(64 * 7 * 7, 10)  # 确保这里的输入特征数与卷积层输出后的特征数相匹配  def forward(self, x):  x = self.conv1(x)  x = self.conv2(x)  x = self.conv3(x)  # 输出应为(batch_size, 64, 7, 7)  x = x.view(x.size(0), -1)  # 展平操作,输出为(batch_size, 64*7*7)  output = self.out(x)  return output  model = CNN().to(device)  
print(model)
  • in_channels=1:这指定了输入图像的通道数。

  • out_channels=16:这指定了卷积操作后输出的通道数,也就是卷积核(或称为滤波器)的数量。

  • kernel_size=5:这定义了卷积核的大小。

  • stride=1:这指定了卷积核在输入数据上滑动的步长。

  • padding=2:这定义了要在输入数据周围添加的零填充(zero-padding)的数量。

四、处理数据集和测试集

训练集处理:

def train(dataloader, model, loss_fn, optimizer):  model.train()  # 将模型设置为训练模式  batch_size_num = 1  # 这不是标准的用法,但在这里用作计数已处理批次的数量  for x, y in dataloader:  # 遍历数据加载器中的每个批次  x, y = x.to(device), y.to(device)  # 将数据和标签移动到指定的设备(如GPU)  pred = model(x)  # 通过模型进行前向传播  loss = loss_fn(pred, y)  # 计算预测和真实标签之间的损失  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向传播,计算当前梯度  optimizer.step()  # 更新模型的权重  loss_value = loss.item()if batch_size_num % 200 == 0:print(f"{loss_value:>7f}[number:{batch_size_num}]")#打印结果batch_size_num += 1  # 增加已处理批次的数量

测试集处理:

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

模型训练:

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for t in range(epochs):print(f"-----------------------------------------------\nepcho{t+1}")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model, loss_fn)

结果:

神经网络:

cnn:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/148140.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

U盘格式化了怎么办?这4个工具能帮你恢复数据。

如果你思维U盘被格式化了,也不用太过担心,其实里面的数据并没有被删除,只是被标记为了可覆盖的状态。只要我们及时采取正确的数据恢复措施,就有很大的机会可以将数据找回。比如使用专业得的数据恢复软件,我也可以跟大家…

Keysight 下载信源 Visa 指令

用于传输原始的IQ数据 file.wiq 或者 file.bin wave_bin:bytes with open("./WaveForm.wfm","rb") as f:wave_bin f.read()log.info("File:WaveForm.wfm Size:%d Bytes"%len(wave_bin)) IMPL.sendCommand(":MEM:DATA \"WFM1:FILE1\&q…

使用 IntelliJ IDEA 连接到达梦数据库(DM)

前言 达梦数据库是一款国产的关系型数据库管理系统,因其高性能和稳定性而被广泛应用于政府、金融等多个领域。本文将详细介绍如何在 IntelliJ IDEA 中配置并连接到达梦数据库。 准备工作 获取达梦JDBC驱动: 访问达梦在线服务平台网站或通过其他官方渠道…

远程升级又双叒叕失败?背后原因竟然是。。。

最近又遇到了远程升级接连失败的情况,耐心和信心都备受折磨! 事情是这样的:有客户反馈在乡村里频繁出现掉线的情况,不敢耽搁,赶紧联系小伙伴排查测试,最后发现,只有去年某一批模块在当下环境才…

Redis:持久化

1. Redis持久化机制 Redis 支持 RDB 和 AOF 两种持久化机制,持久化功能有效地避免因进程退出造成数据丢失问题, 当下次重启时利⽤之前持久化的文件即可实现数据恢复。 2.RDB RDB 持久化是把当前进程数据⽣成快照保存到硬盘的过程,触发 RDB…

c++类中的特殊函数

My_string.cpp #include <iostream> #include "my_string.h" #include <string.h> using namespace std; My_string::My_string():size(15) { this->ptr new char[size] ; this->ptr[0]\0;//串为空串 this->len 0; }; My_string::My_str…

如何使用ssm实现疫苗预约系统+vue

TOC ssm673疫苗预约系统vue 第1章 绪论 1.1选题动因 当前的网络技术&#xff0c;软件技术等都具备成熟的理论基础&#xff0c;市场上也出现各种技术开发的软件&#xff0c;这些软件都被用于各个领域&#xff0c;包括生活和工作的领域。随着电脑和笔记本的广泛运用&#xff…

Django 数据库配置以及字段设置详解

配置PostGre 要在 Django 中配置连接 PostgreSQL 数据库&#xff0c;并创建一个包含“使用人”和“车牌号”等字段的 Car 表 1. 配置 PostgreSQL 数据库连接 首先&#xff0c;在 Django 项目的 settings.py 中配置 PostgreSQL 连接。 修改 settings.py 文件&#xff1a; …

数据结构篇--折半查找【详解】

折半查找也叫做二分查找或者对数查找&#xff0c;是一种在有序数组中查找特定元素的查找算法。 折半查找的算法步骤如下&#xff1a; 将目标关键字key与数组中的中间元素比较&#xff0c;若相等则查找成功。key大于中间元素&#xff0c;就到数组中大于中间元素的部分进行查找&…

超详细超实用!!!AI编程之cursor编写官网新增轮播效果(三)

云风网 云风笔记 云风知识库 index.html内容如下&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"&g…

AI绘画,让AI穿上指定衣服(附工具)

前言 AI绘画的商业应用前景非常广阔&#xff0c;用stable diffusion进行AI绘画时&#xff0c;不仅可以很容易的制作真实人物图片&#xff0c;还能让AI穿上自己指定的衣服&#xff0c;对于做服装生意的电商&#xff0c;可以节省雇佣模特的时间和费用&#xff0c;有效降低成本&a…

JEDEC DDR3 SRAM standard

DDRDouble Data Rate双倍速率,DDR SDRAM双倍速率同步动态随机存储器&#xff0c;人们习惯称为DDR&#xff0c;其中&#xff0c;SDRAM 是Synchronous Dynamic Random Access Memory的缩写&#xff0c;即同步动态随机存取存储器。而DDR SDRAM是Double Data Rate SDRAM的缩写&…

【论文阅读笔记】TOOD: Task-aligned One-stage Object Detection

论文代码&#xff1a;https://github.com/fcjian/TOOD 文章目录 论文小结论文简介论文方法Task-aligned Head&#xff08;T-Head&#xff09;T-Head伪代码解释 Task Alignment Learning&#xff08;TAL&#xff09;Task-aligned Sample AssignmentTask-aligned Loss 论文实验消…

思维商业篇(5)—发展趋势分析

思维商业篇(5)—发展趋势分析 核心理论 巴菲特曾在《滚雪球》一书中提到他的投资之道其实非常简单&#xff0c;可以总结为两句话&#xff1a;找到足够长的雪道&#xff0c;找到足够湿的雪球。 而发展趋势的分析&#xff0c;正好可以借助巴菲特的这个滚雪球理论。 足够长的雪…

内存和管理

在 C 中&#xff0c;对象拷贝时编译器可能会进行一些优化&#xff0c;以提高程序的性能。 一种常见的优化是“返回值优化&#xff08;Return Value Optimization&#xff0c;RVO&#xff09;”和“具名返回值优化&#xff08;Named Return Value Optimization&#xff0c;NRV…

“明月寄情,文化共融”iEnglish助力青少年用英语讲述中国故事

在全球化日益加深的今天&#xff0c;文化的交流与融合成为了不可阻挡的趋势。中秋节&#xff0c;这一承载着中华民族深厚文化底蕴与家国情怀的传统节日&#xff0c;正通过新的方式走向世界舞台。今年中秋&#xff0c;在斐济、澳大利亚、法国等多个国家的中秋文化活动中&#xf…

电脑桌面文件太多太杂?电脑管理软件一键整理,强迫症福音!

电脑桌面文件太多太杂&#xff1f;随着工作量的增加和信息的不断累积&#xff0c;许多人的电脑桌面上往往堆满了各式各样的文件和文件夹&#xff0c;显得杂乱无章。这种“桌面乱象”不仅影响了工作效率&#xff0c;还可能给心理带来不必要的压力&#xff0c;尤其对于那些有强迫…

【RTT-Studio】详细使用教程十六:DAC7311外部DAC使用

文章目录 一、简介二、驱动程序三、DAC设置注册四、完整代码五、测试验证 一、简介 8 位 DAC5311、10 位 DAC6311 和 12 位 DAC7311 (DACx311) 是低功耗、单通道、电压输出数模转换器 (DAC)。DACx311 在正常工作状态下具有低功耗&#xff08;5V 时为 0.55mW&#xff0c;断电模式…

【Qt笔记】QStackedWidget控件详解

目录 引言 一、基础功能 二、属性设置 2.1 属性介绍 2.2 代码示例 2.3 代码解析 三、常用API 3.1 添加子部件 3.2 插入子部件 3.3 移除子部件 3.4 设置当前页面索引值 3.5 设置当前显示子部件 3.6 返回索引处子部件指针 3.7 返回子部件索引值 四、信号与槽 4.…

蓝牙AOA基站助力打造智慧医院管理系统

随着科技的飞速发展&#xff0c;智慧医院的概念逐渐深入人心。其中&#xff0c;蓝牙AOA&#xff08;到达角&#xff09;定位技术以其高精度、低功耗、低成本等优势&#xff0c;在智慧医院建设中扮演着重要角色。本文将深入探讨蓝牙AOA基站如何助力智慧医院的建设与发展。 一、蓝…