使用 TensorFlow 实现 ZFNet 进行 MNIST 图像分类

        ZFNet(ZF-Net)是由 Matthew Zeiler 和 Rob Fergus 提出的卷积神经网络架构,它在图像分类任务中取得了显著的效果。它在标准卷积神经网络(CNN)的基础上做了一些创新,例如优化了卷积核大小和池化策略,使得网络在处理图像时表现得更加高效。

        本文将详细介绍如何使用 TensorFlow 2.x 实现 ZFNet,在 MNIST 数据集上进行图像分类,并将训练部分和测试部分分开进行讲解。

1. 环境准备

        首先,我们需要确保已安装 TensorFlow 和其他相关库。在命令行中执行以下命令进行安装:

pip install tensorflow matplotlib numpy

2. 训练部分:构建和训练 ZFNet 模型

        在训练部分,我们将加载 MNIST 数据集,构建 ZFNet 模型,并在 GPU 或 CPU 上进行训练。

2.1 加载并预处理 MNIST 数据集

        MNIST 数据集包含了 70,000 张手写数字图像,训练集包含 60,000 张,测试集包含 10,000 张。在加载数据后,我们需要对数据进行预处理:标准化和调整大小。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from zfnet import create_zfnet_model  # 从 zfnet.py 导入模型创建函数def prepare_data():"""准备 MNIST 数据集并进行预处理:return: 训练集和测试集的图像及标签"""# 加载 MNIST 数据集(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理:标准化、调整大小、添加维度x_train = x_train.astype('float32') / 255.0x_test = x_test.astype('float32') / 255.0# 调整图像大小并添加额外维度 (32x32, 1通道)x_train = tf.image.resize(x_train[..., tf.newaxis], (32, 32))x_test = tf.image.resize(x_test[..., tf.newaxis], (32, 32))# 确保数据类型是 float32x_train = tf.cast(x_train, tf.float32)x_test = tf.cast(x_test, tf.float32)# 类别标签 one-hot 编码y_train = to_categorical(y_train, 10)y_test = to_categorical(y_test, 10)return x_train, y_train, x_test, y_test

解释:
  1. 标准化:图像像素值从 [0, 255] 转换为 [0, 1],有助于加速网络训练并提高稳定性。
  2. 调整图像大小:由于 ZFNet 网络需要 32x32 的输入图像,所以我们将图像大小调整为 32x32。
  3. One-Hot 编码:标签数据转换为 One-Hot 编码格式,以便与神经网络输出匹配。

2.2 创建 ZFNet 模型

        ZFNet 是一个深度卷积神经网络,它的设计关注如何高效地提取图像特征。我们通过以下代码来构建 ZFNet 模型。

from tensorflow.keras import layers, modelsdef create_zfnet_model(input_shape=(32, 32, 1), num_classes=10):"""创建 ZFNet 模型。参数:- input_shape: 输入图像的形状,默认 (32, 32, 1)。- num_classes: 类别数目,默认 10。返回:- 返回构建好的模型。"""model = models.Sequential()# 使用 Input 层显式定义输入形状model.add(layers.Input(shape=input_shape))  # 显式指定输入形状# 特征提取部分model.add(layers.Conv2D(64, (7, 7), activation='relu', strides=2, padding='same'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same'))model.add(layers.Conv2D(128, (5, 5), activation='relu', padding='same'))model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same'))model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))model.add(layers.Conv2D(512, (3, 3), activation='relu', padding='same'))# 扁平化层model.add(layers.Flatten())# 全连接层model.add(layers.Dense(1024, activation='relu'))model.add(layers.Dropout(0.5))# 输出层model.add(layers.Dense(num_classes, activation='softmax'))return model

解释:
  • 卷积层:通过多个卷积层提取图像的空间特征。ZFNet 采用不同大小的卷积核(如 7x7、5x5 和 3x3),通过优化的卷积结构捕捉更多层次的图像信息。
  • 池化层:最大池化层用于减少图像尺寸,并使特征保持重要信息。
  • 全连接层:通过扁平化和全连接层进一步处理特征,并输出分类结果。

2.3 编译与训练模型

        在训练之前,我们需要编译模型并选择优化器和损失函数。然后,调用 fit 函数开始训练。

def compile_model(model):"""编译模型:param model: 待编译的模型:return: 已编译的模型"""model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])return modeldef train_model(model, x_train, y_train, x_test, y_test, device, epochs=5, batch_size=128):"""在指定设备上训练模型:param model: 训练的模型:param x_train: 训练集图像:param y_train: 训练集标签:param x_test: 测试集图像:param y_test: 测试集标签:param device: 设备:param epochs: 训练轮数:param batch_size: 批处理大小"""with tf.device(device):model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, y_test))

解释:
  • 优化器:我们使用 Adam 优化器,它具有自适应学习率,非常适合深度学习任务。
  • 损失函数categorical_crossentropy 用于多分类问题。
  • 训练:通过 model.fit() 函数训练模型,并在每个 epoch 后使用测试数据进行验证。

3. 测试部分:评估模型并进行预测

        一旦训练完成,我们将评估模型在测试集上的表现,并可视化其预测结果。

3.1 评估模型

def evaluate_model(model, x_test, y_test):"""评估模型在测试集上的表现:param model: 训练好的模型:param x_test: 测试集图像:param y_test: 测试集标签:return: 测试集上的损失和准确率"""test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)print(f"Test accuracy: {test_acc}")return test_loss, test_acc

解释:
  • 使用 evaluate() 方法评估模型的性能,返回模型的损失和准确率。

3.2 可视化预测结果

def visualize_predictions(model, x_test, y_test, num_images=6):"""可视化模型对多张测试图片的预测结果:param model: 训练好的模型:param x_test: 测试集图像:param y_test: 测试集标签:param num_images: 显示图像的数量"""predictions = model.predict(x_test[:num_images])predicted_labels = np.argmax(predictions, axis=1)actual_labels = np.argmax(y_test[:num_images], axis=1)# 绘制结果fig, axes = plt.subplots(2, 3, figsize=(10, 7))axes = axes.ravel()for i in range(num_images):ax = axes[i]# 将 Tensor 转换为 NumPy 数组,并使用 reshapeimg = x_test[i].numpy().reshape(32, 32)  # 这里调用 .numpy() 将 Tensor 转换为 NumPy 数组ax.imshow(img, cmap='gray')ax.set_title(f"Pred: {predicted_labels[i]} | Actual: {actual_labels[i]}")ax.axis('off')plt.tight_layout()plt.show()

解释:
  • 预测结果可视化:我们选择部分图像进行预测并显示模型的预测标签和真实标签,帮助分析模型的分类效果。

3.3 计算整体准确率

# 计算整体准确率accuracy = np.sum(predicted_labels == actual_labels) / len(actual_labels)print(f"Accuracy on the entire test set: {accuracy * 100:.2f}%")

解释:
  • 通过对比预测标签和实际标签,计算模型在测试集上的整体准确率。

4. 总结

        本文介绍了如何使用 TensorFlow 实现 ZFNet 网络,并在 MNIST 数据集上进行训练和测试。通过训练模型、评估性能、可视化预测结果,我们能够更好地理解 ZFNet 的优势和图像分类中的应用。

        希望这篇博客能帮助你掌握 ZFNet 的实现过程,理解其背后的原理,并能够顺利地应用到其他图像分类任务中!

        如有问题或进一步的疑问,请随时留言讨论!

完整项目:

https://github.com/qxd-ljy/ZFNet-TensorFlowicon-default.png?t=O83Ahttps://github.com/qxd-ljy/ZFNet-TensorFlowZFNet-TensorFlow: 使用 TensorFlow 实现 ZFNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/zfnet-tensor-flow

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/17571.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

11.15 HTML

传统路线 HTML、CSS、JS AjaxJQueryMySQLJDBCServletJSPEL&JSTLCookieSessionFilterServlet案例MybatisSpringSpringMVCSpringBoot 全新路线 HTM、CSS、JSAjax、AxiosVue、Element前端工程化 vue脚手架MavenSpringBoot基础 基于SpringBoot进行讲解Spring的IOC&#xff…

打造旅游卡服务新标杆:构建SOP框架与智能知识库应用

随着旅游业的蓬勃兴起,旅游卡产品正逐渐成为市场的焦点。为了进一步提升服务质量和客户体验,构建一套高效且标准化的操作流程(SOP)变得尤为重要。本文将深入探讨如何构建旅游卡的SOP框架,并介绍如何利用智能知识库技术…

Java 简单家居开关系统

1.需求: 面向对象编程实现智能家居控制系统(简单的开关) 2.实现思路 1.定义设备类:创建设备对象代表家里的设备 JD类: import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor;D…

Github客户端工具github-desktop使用教程

文章目录 1.客户端工具的介绍2.客户端工具使用感受3.仓库的创建4.初步尝试5.本地文件和仓库路径5.1原理说明5.2修改文件5.3版本号的说明5.4结合码云解释5.5版本号的查找 6.分支管理6.1分支的引入6.2分支合并6.3创建测试仓库6.4创建测试分支6.5合并分支6.6合并效果查看6.7分支冲…

3D Gaussian Splatting的全面理解

1.概述 高斯展开是一种表示 3D 场景和渲染新视图的方法,在“用于实时辐射场渲染的 3D 高斯展开” 中介绍。它可以被认为是类似 NeRF 的模型的替代品,就像过去的 NeRF 一样,高斯飞溅导致了许多新的研究工作,他们选择将其用作各种用例的 3D 世界的底层表示。那么它有什么特别…

Arcgis地图实战三:自定义导航功能的实现

文章目录 1.最终效果预览2.计算两点之间的距离3.将点线画到地图上4.动态展示点线的变化5.动态画线6.动态画点 1.最终效果预览 2.计算两点之间的距离 let dis this.utilsTools.returnDisByCoorTrans(qdXYData, zdXYData, "4549")当距离小于我们在配置文件中预设置的…

【Mysql】Mysql的多表查询---多表联合查询(中)

1、外连接查询 外连接 查询分为左外连接(left outer join), 右外连接查询(right outer join) ,满外连接查询(full outer join). 注意:oracle 里面有full join &#xf…

Linux:进程状态

文章目录 前言一、初识fork1.1 fork函数的介绍1.2 fork出的子进程存在形式1.3 写时拷贝 二、进程的状态2.1 Linux内核源代码2.2 理解内核链表(重要)2.3 运行状态2.4 阻塞状态2.5 挂起状态 三、Z(zombie)状态 ,僵尸进程四、 孤儿进程总结 前言…

qml显示加载嵌入QWidget窗口

本篇博客介绍如何在qml界面里显示QWidget窗口,开发环境Qt6.5.3 qml. 视频讲解:https://edu.csdn.net/learn/40003/654001?spm=3001.4143 qml和QWidget是两套独立的开发方式,二者的窗口可以相互嵌套显示,本篇博客介绍把QWidget窗口封装为动态库,然后在QML的窗口里显示出来…

【MySQL】多表查询

5. 多表查询 5.1 多表关系 项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互关联,所以各个表结构之间也存在着各种联系,基本上分为三种&#…

2024-11-16 串的存储结构

一、顺序存储。 1.首先定一个静态数组,然后定义i记录串的实际长度。(缺点:长度不可变) 2.使用malloc申请动态空间,定义指针指向串的地址。(需手动ferr) 方案一: 数组末尾记录长度 …

nodejs21: 快速构建自定义设计样式Tailwind CSS

Tailwind CSS 是一个功能强大的低级 CSS 框架,只需书写 HTML 代码,无需书写 CSS,即可快速构建美观的网站。 1. 安装 Tailwind CSS React 项目中安装 Tailwind CSS: 1.1 安装 Tailwind CSS 和相关依赖 安装 Tailwind CSS: npm…

Windows 安装Docker For Desktop概要

Windows 安装docker 下载部分的工作需要使用科学技术。如果没有可以联系博主发送已下载好的文件。 本文档不涉及技术的讲解,仅有安装的步骤。 准备工作 包含下载与环境准备,下载的文件仅下载,在后续步骤进行安装。 微软关于wsl的文档&…

对称加密算法DES的实现

一、实验目的 1、了解对称密码体制基本原理 2、掌握编程语言实现对称加密、解密 二、实验原理 DES 使用一个 56 位的密钥以及附加的 8 位奇偶校验位,产生最大 64 位的分组大小。这是一个迭代的分组密码,使用称为 Feistel 的技术,其中将加密…

三十八、Python(pytest框架-上)

一、介绍 框架(framework):框架是为解决一类事情的功能集合。 pytest框架:pytest框架是单元测试框架,这是第三方框架想要使用必须要安装,可以使用pytest来作为自动化测试执行框架,用来管理测试…

《Django 5 By Example》阅读笔记:p165-p210

《Django 5 By Example》学习第6天,p165-p210总结,总计46页。 一、技术总结 1.bookmarks项目 (1)登录认证 作者这里使用的是Django自带的auth。 (2)上传头像 图片处理,使用Pillow。 (3)扩展user 扩展user模型与自带的user使用外键进行…

shell基础(3)

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团…

JVM面试题总结

1.介绍一下JVM的内存结构 JDK1.8及以后,JVM主要分为元空间、堆、虚拟机栈、本地方法栈、程序计数器五个部分,另外还有一个直接内存部分,是直接属于操作系统的。 其中元空间、堆是线程共享的,虚拟机栈、本地方法栈、程序计数器是线…

小新Pro 14 AHP9 2024款(83D3)原装oem预装系统Win11恢复安装包下载

适用品牌机型 :LENOVO联想【83D3】 链接:https://pan.baidu.com/s/10RAxNdvYPWJ21b_4--Y7Xw?pwdo5ju 提取码:o5ju 联想原装出厂Windows11系统自带所有驱动、出厂主题壁纸、系统属性联机支持标志、系统属性专属LOGO标志、Office365办公软…

【论文笔记】Towards Privacy-Aware Sign Language Translation at Scale

🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 基本信息 标题: Towards Privacy-Aware Si…