手写数字识别(分类任务)

1. 导入必要的库

from pathlib import Path
import requests
import pickle
import gzip
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
  • 功能: 导入所需的库,以便进行文件操作、数据处理、构建神经网络、计算损失以及加载和处理数据。

2. 数据准备

a. 创建数据路径并下载数据集
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
  • 功能:
    • 创建存储数据集的路径 data/mnist
    • 从指定 URL 下载 MNIST 数据集,并保存为 mnist.pkl.gz 文件,如果文件已经存在则不重复下载。
b. 加载数据集
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
  • 功能:
    • 使用 gzip 解压缩文件并使用 pickle 加载数据集,得到训练数据 x_train、训练标签 y_train、验证数据 x_valid 和验证标签 y_valid

3. 数据转换为 PyTorch 张量

x_train_test, y_train_test, x_valid_test, y_valid_test = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
  • 功能: 将 NumPy 数组转换为 PyTorch 张量,以便后续的模型训练和计算。

4. 模型构建

a. 定义神经网络结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)  # 第一隐藏层self.hidden2 = nn.Linear(128, 256)  # 第二隐藏层self.out = nn.Linear(256, 10)       # 输出层self.dropout = nn.Dropout(0.5)      # Dropout层
  • 功能:
    • 创建一个名为 Mnist_NN 的神经网络类,继承自 nn.Module
    • 在初始化方法中定义网络的结构,包括两个隐藏层和一个输出层,以及一个 Dropout 层以减少过拟合。
b. 定义前向传播方法
def forward(self, x):x = F.relu(self.hidden1(x))  # 第一层激活x = self.dropout(x)          # Dropoutx = F.relu(self.hidden2(x))  # 第二层激活x = self.dropout(x)          # Dropoutx = self.out(x)              # 输出层return x
  • 功能:
    • 定义前向传播过程,输入数据通过各层进行计算,并应用 ReLU 激活函数和 Dropout。

5. 损失函数和优化器

a. 定义损失函数
loss_func = F.cross_entropy  # 使用交叉熵损失函数
  • 功能: 选择交叉熵损失函数作为模型的损失计算标准。
b. 定义优化器
def get_model():model = Mnist_NN()  # 实例化模型return model, optim.SGD(model.parameters(), lr=0.001)  # 使用 SGD 优化器
  • 功能:
    • 创建模型实例,并定义 SGD 优化器,学习率设置为 0.001。

6. 数据加载

a. 创建数据集和数据加载器
train_ds = TensorDataset(x_train, y_train)  # 创建训练数据集
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)  # 创建训练数据加载器valid_ds = TensorDataset(x_valid, y_valid)  # 创建验证数据集
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)  # 创建验证数据加载器
  • 功能:
    • 将训练和验证数据转换为 TensorDataset 格式,以便于进行批量处理。
    • 创建数据加载器 DataLoader,在训练时打乱训练集的顺序,方便分批次读取数据。

6.5 loss_batch

def loss_batch(model, loss_func, xb, yb, opt=None):  # 计算当前批次的损失loss = loss_func(model(xb), yb)  # 如果提供了优化器 opt,则进行反向传播和优化if opt is not None:loss.backward()  # 计算损失的梯度(反向传播)opt.step()       # 更新模型参数opt.zero_grad()  # 清空梯度,避免累积# 返回当前批次的损失值和该批次数据的大小return loss.item(), len(xb)  

7. 训练过程

a. 定义训练函数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()  # 设置模型为训练模式for xb, yb in train_dl:  # 遍历训练数据loss_batch(model, loss_func, xb, yb, opt)  # 计算并优化损失model.eval()  # 设置模型为评估模式with torch.no_grad():  # 禁用梯度计算losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])  # 计算验证集损失val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)  # 计算加权平均损失print('当前step:' + str(step), '验证集损失:' + str(val_loss))  # 打印损失
  • 功能:
    • 定义 fit 函数,负责训练过程,包括:
      • 将模型设置为训练模式并进行训练。
      • 在验证模式下计算验证集损失并输出。

8. 运行训练

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)  # 获取数据加载器
model, opt = get_model()  # 获取模型和优化器
fit(25, model, loss_func, opt, train_dl, valid_dl)  # 开始训练
  • 功能:
    • 通过调用 get_dataget_model 函数获取数据加载器和模型,然后调用 fit 函数进行训练。

总结

以上步骤展示了从数据准备到模型训练的完整过程。每一步都围绕着构建一个用于手写数字识别的神经网络进行,确保数据的加载、模型的构建和训练过程都能顺利进行。通过这些步骤,最终可以得到一个能够对手写数字进行分类的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1557172.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

毕设 大数据电影数据分析与可视化系统(源码+论文)

文章目录 0 前言1 项目运行效果2 设计概要3 最后 0 前言 🔥这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要求,这两年不断有学弟学妹告诉学长自己做的项目系统达不到老师…

51 单片机最小系统

一、51 单片机最小系统概述 51 单片机最小系统是一个基于 51 单片机的最小化电路系统,它包含了使单片机能够正常工作的最少元件。这个系统主要用于学习和实验目的,帮助学习者在没有复杂电路的情况下快速了解 51 单片机的工作原理,其重要性不…

动态规划算法专题(四):子串、子数组系列

目录 1、最大子数组和 1.1 算法原理 1.2 算法代码 2、环形子数组的最大和 2.1 算法原理 2.2 算法代码 3、乘积最大子数组 3.1 算法原理 3.2 算法代码 4、乘积为正数的最长子数组长度 4.1 算法原理 4.2 算法代码 5、等差数列划分 5.1 算法原理 5.2 算法代码 6、…

java语言基础案例-cnblog

java语言基础案例 象棋口诀 输出 package nb;public class XiangQi {public static void main(String[] args) {char a 马;char b 象;char c 卒;System.out.println(a"走日"b"走田""小"c"一去不复还");} }输出汇款单 package nb…

30 树 · 二叉树

目录 一、树 (一)树的概念与结构 (二)树相关术语 (三)树的表示 (四)树形结构的实际应用场景 二、二叉树 (一)概念与结构 (二)…

【LeetCode】每日一题 2024_10_7 最低加油次数(堆、贪心)

前言 每天和你一起刷 LeetCode 每日一题~ 大家国庆节快乐呀~ LeetCode 启动! 国庆最后一天,力扣还在加油站,怕不是国庆回家路上堵车了 题目:最低加油次数 代码与解题思路 func minRefuelStops(target int, startFuel int, st…

刷题 双指针 滑动窗口

面试经典 150 题 - 双指针 ⭐️125. 验证回文串 学会内部字母处理函数的使用 class Solution { public:bool isPalindrome(string s) {int left 0, right s.size() - 1;while (left < right) {// 处理左边字符if (!isalnum(s[left])) {left;continue;}// 处理右边字符if…

C(十五)函数综合(一)--- 开公司吗?

在这篇文章中&#xff0c;杰哥将带大家 “开公司”。 主干内容部分&#xff08;你将收获&#xff09;&#xff1a;&#x1f449; 为什么要有函数&#xff1f;函数有哪些&#xff1f;怎么自定义函数以及获得函数的使用权&#xff1f;怎么对函数进行传参&#xff1f;函数中变量的…

C语言 | Leetcode C语言题解之第462题最小操作次数使数组元素相等II

题目&#xff1a; 题解&#xff1a; static inline void swap(int *a, int *b) {int c *a;*a *b;*b c; }static inline int partition(int *nums, int left, int right) {int x nums[right], i left - 1;for (int j left; j < right; j) {if (nums[j] < x) {swap(…

Linux 外设驱动 应用 1 IO口输出

从这里开始外设驱动介绍&#xff0c;这里使用的IMX8的芯片作为驱动介绍 开发流程&#xff1a; 修改设备树&#xff0c;配置 GPIO1_IO07 为 GPIO 输出。使用 sysfs 接口或编写驱动程序控制 GPIO 引脚。编译并测试。 这里假设设备树&#xff0c;已经配置好了。不在论述这个问题…

【英语】5. 考研英语语法体系

文章目录 前言句字的成分一、常规句型简单句&#xff08;5 种&#xff09;1. 定义&#xff1a;句子中只包含 *一套主谓结构* 的句子。&#xff08;一个句子只能有一个谓语动词&#xff09;2. 分类 并列句&#xff08;由关联词组成&#xff09;&#xff08;3 种&#xff09;基本…

二分图算法总结 C++实现

总体概念 染色法 基本思路步骤 将所有的边及其相接的边用邻接表存储起来&#xff1b;遍历所有的点&#xff0c;找到未上色的点&#xff1b;用BFS将该点及其相接的点迭代上色&#xff1b;在上述染色步骤中&#xff0c;如果相邻点的颜色相同则无法形成二分图&#xff1b; 题目…

继电保护之电压重动、电压并列和电压切换

实践&#xff1a;以某开关室10kV母联隔离柜为例&#xff1a; ZYQ-824为PT并列装置&#xff0c;装置内包含一系列继电器&#xff0c;用于PT重动及并列。按照装置编号原则&#xff0c;交流电压切换箱一般命名为7n。 ​下图为装置内继电器线圈部分接线&#xff1a; 下图为装置内…

销售秘籍:故事+观点+结论

在销售的浩瀚宇宙中&#xff0c;隐藏着一个不朽的秘诀——利用人类共有的“错失恐惧”&#xff0c;激发客户内心的渴望与行动。正如村上春树所言&#xff0c;每个故事都深深植根于灵魂&#xff0c;而大仲马则揭示&#xff0c;灵魂之眼所见&#xff0c;比肉眼更为长久铭记。 错…

如何将数据从 AWS S3 导入到 Elastic Cloud - 第 1 部分:Elastic Serverless Forwarder

作者&#xff1a;来自 Elastic Hemendra Singh Lodhi 这是多部分博客系列的第一部分&#xff0c;探讨了将数据从 AWS S3 导入 Elastic Cloud 的不同选项。 Elasticsearch 提供了多种从 AWS S3 存储桶导入数据的选项&#xff0c;允许客户根据其特定需求和架构策略选择最合适的方…

Mysql锁机制解读(敲详细)

目录 锁的概念 全局锁 表级锁 表锁 元数据锁 意向锁 锁的概念 全局锁 表级锁 表锁 元数据锁 主要是对未提交事务&#xff0c;修改表结构造成表结构混乱&#xff0c;进行控制。 在不涉及表结构变化的情况下,元素锁可以忽略。 意向锁 避免有行级锁影响加表级锁&#xff0…

YoloV9改进策略:BackBone改进|CAFormer在YoloV9中的创新应用,显著提升目标检测性能

摘要 在目标检测领域,模型性能的提升一直是研究者和开发者们关注的重点。近期,我们尝试将CAFormer模块引入YoloV9模型中,以替换其原有的主干网络,这一创新性的改进带来了显著的性能提升。 CAFormer,作为MetaFormer框架下的一个变体,结合了深度可分离卷积和普通自注意力…

​.NET一款反序列化执行命令的白名单工具

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等&#xff08;包括但不限于&#xff09;进行检测或维护参考&#xff0c;未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

Unity_Obfuscator Pro代码混淆工具_学习日志

Unity_Obfuscator Pro代码混淆工具_学习日志 切勿将密码或 API 密钥存储在您附带的应用程序内。 混淆后的热更新暂时没有想到怎么办 Obfuscator 文档 https://docs.guardingpearsoftware.com/manual/Obfuscator/Description.html商店链接Obfuscator Pro&#xff08;大约$70&a…

sqli-labs靶场第三关less-3

sqli-labs靶场第三关less-3 1、确定注入点 http://192.168.128.3/sq/Less-3/?id1 http://192.168.128.3/sq/Less-3/?id2 有不同回显&#xff0c;判断可能存在注入&#xff0c; 2、判断注入类型 输入 http://192.168.128.3/sq/Less-3/?id1 and 11 http://192.168.128.3/sq/L…