BP实战minist数据集

目录

前言

一、MNIST数据集介绍和加载

1.MNIST数据集介绍

2.加载数据集MNIST数据集

二、构建 BP 网络模型

1.神经网络结构图示

2.BP 网络模型代码解释

三、定义和训练BP 网络模型

四、训练结果

总结


前言

在当今人工智能与机器学习飞速发展的时代,神经网络作为一种强大的工具,在图像识别、自然语言处理等众多领域都展现出了卓越的性能。其中,BP(Back Propagation)神经网络作为一种经典的前馈神经网络,以其简单的结构和高效的学习能力,一直备受研究者和开发者的青睐。

本次实战,我们将目光聚焦于著名的 MNIST 数据集。MNIST 数据集由手写数字的图像组成,它具有规模适中、问题清晰等特点,非常适合作为神经网络的入门实战案例。

通过使用 BP 神经网络对 MNIST 数据集中的手写数字进行识别,我们将深入了解神经网络的工作原理、训练过程以及在实际问题中的应用。 在这个过程中,我们将逐步探索如何构建 BP 神经网络模型、如何加载和预处理数据集、如何进行模型的训练和优化,以及如何评估模型的性能。


一、MNIST数据集介绍和加载

1.MNIST数据集介绍

MNIST 数据集是机器学习领域中广泛使用的一个基准数据集,主要用于图像识别和数字分类任务。

MNIST 数据集由手写数字的图像组成,这些数字是从 0 到 9 的整数。它包含了 70,000 张灰度图像,其中 60,000 张用于训练,10,000 张用于测试。每一张图像都是 28×28 像素的,呈现出不同人书写的数字形态,具有一定的多样性和复杂性。

该数据集的图像是灰度的且数字居中,这在一定程度上减少了预处理的工作量并加快了模型的运行速度。其简洁明了的特点使得 MNIST 成为初学者进入机器学习和深度学习领域的理想选择,许多经典的算法和模型都首先在这个数据集上进行验证和优化。 MNIST 数据集的广泛应用推动了图像识别技术的发展,研究人员通过在这个数据集上不断尝试新的算法和改进模型结构,为更复杂的图像识别任务奠定了基础。

2.加载数据集MNIST数据集

# MNIST 包含 70,000 张手写数字图像: 60,000 张用于训练,10,000 张用于测试。
# 图像是灰度的,28×28 像素的,并且居中的,以减少预处理和加快运行。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 使用 torchvision 读取数据
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 使用 DataLoader 加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

 首先定义了一个数据转换 transform,包括将图像转换为张量并进行归一化处理。然后使用 torchvision.datasets.MNIST 加载 MNIST 数据集,分别设置 train=True 和 train=False 来获取训练集和测试集。最后使用 torch.utils.data.DataLoader 将数据集包装成数据加载器,设置了批量大小为 64,训练集进行随机打乱,测试集不打乱。

二、构建 BP 网络模型

# 第 1 步:构建 BP 网络模型
class BPNetwork(torch.nn.Module):def __init__(self):super(BPNetwork, self).__init__()"""定义第一个线性层,输入为图片(28x28),输出为第一个隐层的输入,大小为 128。"""self.linear1 = torch.nn.Linear(28 * 28, 128)# 在第一个隐层使用 ReLU 激活函数self.relu1 = torch.nn.ReLU()"""定义第二个线性层,输入是第一个隐层的输出,输出为第二个隐层的输入,大小为 64。"""self.linear2 = torch.nn.Linear(128, 64)# 在第二个隐层使用 ReLU 激活函数self.relu2 = torch.nn.ReLU()"""定义第三个线性层,输入是第二个隐层的输出,输出为输出层,大小为 10"""self.linear3 = torch.nn.Linear(64, 10)# 最终的输出经过 softmax 进行归一化self.softmax = torch.nn.LogSoftmax(dim=1)def forward(self, x):"""定义神经网络的前向传播x: 图片数据, shape 为(64, 1, 28, 28)"""# 首先将 x 的 shape 转为(64, 784)x = x.view(x.shape[0], -1)# 接下来进行前向传播x = self.linear1(x)x = self.relu1(x)x = self.linear2(x)x = self.relu2(x)x = self.linear3(x)x = self.softmax(x)# 上述一串,可以直接使用 x = self.model(x) 代替。return x


1.神经网络结构图示

 

层名输入大小输出大小
输入层(展平后的图片)784-
第一个隐藏层784128
第二个隐藏层12864
输出层6410

2.BP 网络模型代码解释

定义了一个名为 BPNetwork 的类,继承自 torch.nn.Module,用于构建一个三层的神经网络模型。在 __init__ 方法中定义了三个线性层和两个 ReLU 激活函数以及一个对数 softmax 函数用于输出层的归一化。在 forward 方法中定义了神经网络的前向传播过程,首先将输入的图片数据展平为一维向量,然后依次通过三个线性层和激活函数,最后经过 softmax 归一化得到输出。

三、定义和训练BP 网络模型

model = BPNetwork()
# criterion = torch.nn.MSELoss()
criterion = torch.nn.NLLLoss()                                            # 定义 loss 函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)   # 定义优化器epochs = 15                               # 一共训练 15 轮
for i in range(epochs):running_loss = 0                     # 本轮的损失值for images, labels in trainloader:# 前向传播获取预测值output = model(images)# 计算损失loss = criterion(output, labels)# 进行反向传播loss.backward()# 更新权重optimizer.step()# 清空梯度optimizer.zero_grad()# 累加损失running_loss += loss.item()# 一轮循环结束后打印本轮的损失函数print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))

 这里首先创建了一个 BPNetwork 模型实例,然后定义了损失函数为负对数似然损失(torch.nn.NLLLoss),优化器为随机梯度下降(torch.optim.SGD),设置了学习率为 0.003 和动量为 0.9。接着设置了训练轮数为 15。在训练循环中,遍历训练数据加载器,进行前向传播得到预测值,计算损失,然后进行反向传播、更新权重和清空梯度。最后打印每一轮的训练损失。

四、测试模型

examples = enumerate(testloader)
batch_idx, (imgs, labels) = next(examples)fig = plt.figure()
for i in range(64):logps = model(imgs[i])                    # 通过模型进行预测probab = list(logps.detach().numpy()[0])  # 将预测结果转为概率列表。[0]是取第一张照片的 10 个数字的概率列表(因为一次只预测一张照片)pred_label = probab.index(max(probab))    # 取最大的 index 作为预测结果img = torch.squeeze(imgs[i])img = img.numpy()plt.subplot(8, 8, i + 1)plt.tight_layout()plt.imshow(img, cmap='gray', interpolation='none')plt.title("预测值: {}".format(pred_label))plt.xticks([])plt.yticks([])plt.show()

首先从测试数据加载器中获取一批数据,然后创建一个 matplotlib 的图形对象。接着在一个循环中,对这批数据中的前 64 张图像进行预测,将预测结果转换为概率列表,取最大概率的索引作为预测标签。同时将图像数据转换为 numpy 数组并进行展示,在图像上标注预测值。最后显示绘制的图形,展示测试结果。 

五、训练结果

 


总结

用 BP 神经网络对 MNIST 数据集进行实战。首先构建了一个包含两个隐藏层的三层神经网络模型,使用全连接层和 ReLU 激活函数,输出层经 softmax 归一化。接着加载 MNIST 数据集并预处理,用数据加载器进行高效加载。然后定义损失函数和优化器进行模型训练,通过前向传播、计算损失、反向传播等步骤更新权重。最后在测试集上进行预测,展示图像及预测结果。此实战有助于理解神经网络原理和训练过程,为深入学习提供基础和经验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1557395.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

SPI主从通讯稳定性之解决方法

在使用SPI通讯时,将硬件SPI用作主机的比较多,程序设计也比较容易,但是,若将硬件SPI用作从机了,网上的案例就比较少了,因为大家都有一个习惯,实在实现不了,就用软件模拟SPI来完成通讯…

函数式接口在Java中的应用与实践

1. 引言 函数式接口是Java 8引入的一个概念,它是指只有一个抽象方法的接口。函数式接口可以被用作lambda表达式的目标类型。在函数式接口中,除了抽象方法外,还可以有默认方法和静态方法。 函数式接口的引入是为了支持函数式编程&#xff0c…

Java项目: 基于SpringBoot+mybatis+maven+vue网上摄影工作室(含源码+数据库+任务书+毕业论文)

一、项目简介 本项目是一套基于SpringBootmybatismavenmavenvue网上摄影工作室 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、…

【算法】博弈论(C/C++)

个人主页:摆烂小白敲代码 创作领域:算法、C/C 持续更新算法领域的文章,让博主在您的算法之路上祝您一臂之力 欢迎各位大佬莅临我的博客,您的关注、点赞、收藏、评论是我持续创作最大的动力 目录 博弈论: 1. Grundy数…

【MySQL】-- 表的操作

文章目录 1. 查看所有表1.1 语法 2. 创建表2.1 语法2.2 示例2.3 表在磁盘上对应的文件 3. 查看表结构3.1 语法3.2 示例 4. 查看创建表的语句5. 修改表5.1 语法5.2 示例5.2.1 向表中添加一列5.2.2 修改某列的长度5.2.3 重命名某列5.2.4 删除某个字段5.2.5 修改表名 6. 删除表6.1…

不入耳开放式耳机哪个品牌好?开放式耳机排行榜10强推荐!

不入耳开放式耳机哪个品牌好?开放式耳机排行榜10强推荐! 随着开放式耳机的日益流行,市场上的选择愈发多样,这有时会让消费者在挑选时感到迷茫,不知道哪个牌子的开放式耳机最好。为解决这一困扰,我精心筛选…

社区圈子系统 圈子社区系统 兴趣社区圈子论坛系统 圈子系统源码圈子系统的适用领域有哪些?如何打造自己的圈子圈子系统有哪些常见问题

社区圈子系统 圈子社区系统 兴趣社区圈子论坛系统 圈子系统源码圈子系统的适用领域有哪些?如何打造自己的圈子圈子系统有哪些常见问题 圈子系统的适用领域 圈子系统的适用领域广泛,涵盖了多个行业和场景,包括但不限于以下几个方面&#xff1…

Label Studio 半自动化标注

引言 Label Studio ML 后端是一个 SDK,用于包装您的机器学习代码并将其转换为 Web 服务器。Web 服务器可以连接到正在运行的 Label Studio 实例,以自动执行标记任务。我们提供了一个示例模型库,您可以在自己的工作流程中使用这些模型,也可以根据需要进行扩展和自定义。 1…

springboot邮件群发功能的开发与优化策略?

springboot邮件配置指南?如何实现spring邮件功能? SpringBoot框架因其简洁、高效的特点,成为了开发邮件群发功能的理想选择。AokSend将深入探讨SpringBoot邮件群发功能的开发过程,并提出一系列优化策略,以确保邮件发送…

常见的图像处理算法:均值滤波----mean filter

一、什么是均值滤波 均值滤波器是一种常见的图像滤波器,是典型的线性滤波算法。其基本原理是用一个给定的窗口覆盖图像中的每一个像素点,将窗口内的像素值求平均值,然后用这个平均值代替原来的像素值。均值滤波器可以去除噪声、平滑图像、减少…

代码随想录算法训练营Day28 | 39. 组合总和、40.组合总和Ⅱ、131.分割回文串

目录 39. 组合总和 40.组合总和Ⅱ 131.分割回文串 39. 组合总和 题目 39. 组合总和 - 力扣(LeetCode) 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不…

路径规划关于地图的整理

路径规划离不开地图,其中真实地图,栅格地图和RVIZ之间Grid显示之间很混乱,还有各个原点位置显示,不弄清发现map在rviz里显示老是偏的,专门学习记录一下。 RVIZ里Grid的全局坐标系原点,在默认在栅格中间&am…

软考学习笔记

学习资料: 数据库关系模式的范式总结_关系模式范式-CSDN博客 【范式】五大范式所解决的问题及说明_天空_新浪博客 (sina.com.cn) 求函数依赖: 根据函数依赖求候选码_证明存在部分依赖属于候选码-CSDN博客 关系范式: 1NF:若关…

xss-labs靶场第二关测试报告

目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、注入点寻找 2、使用hackbar进行payload测试 3、绕过结果 四、源代码分析 五、结论 一、测试环境 1、系统环境 渗透机:本机(127.0.0.1) 靶 机:本机(127.0.0.…

刷题 二叉树

面试经典 150 题 - 二叉树 104. 二叉树的最大深度 广度优先遍历 class Solution { public:// 广度优先遍历int maxDepth(TreeNode* root) {if (root nullptr) return 0;queue<TreeNode*> que;que.push(root);int result 0;while (!que.empty()) {result;int num que…

看《米小圈动画汉字》轻松掌握汉字的起源、演变和应用!

在这个充满探索与发现的年纪&#xff0c;孩子刚刚从幼儿园的温暖怀抱中走出&#xff0c;踏入了小学的大门。对于每一个小学生而言&#xff0c;这不仅是一个新环境的适应&#xff0c;更是知识大门的开启。汉字&#xff0c;这一古老而美丽的文字&#xff0c;承载着丰富的文化与历…

【JAVA基础】集合类之Hash的原理及应用

近期几期内容都是围绕该体系进行知识讲解&#xff0c;以便于同学们学习Java集合篇知识能够系统化而不零散。 本文将介绍HashSet的基本概念&#xff0c;功能特点&#xff0c;使用方法&#xff0c;以及优缺点分析和应用场景案例。 一、概念 HashSet是 Java 集合框架中的一个重…

思科防火墙:ASA中Object-group在ACL中的应用

一、实验拓扑&#xff1a; 二、实验要求&#xff1a; 先定义几个小的&#xff0c;然后用大的包在一起&#xff1b;打包在一起&#xff0c;这就是所谓的嵌套&#xff0c;嵌套在编程里是很长用的东西&#xff0c;叫做Object-group&#xff1b; Object-group比较强大&#xff0c;可…

【JAVA开源】基于Vue和SpringBoot的师生共评作业管理系统

本文项目编号 T 071 &#xff0c;文末自助获取源码 \color{red}{T071&#xff0c;文末自助获取源码} T071&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

学习threejs,模拟窗户光源

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言二、&#x1f340;绘制任意字体模型…