使用TensorFlow实现简化版 GoogLeNet 模型进行 MNIST 图像分类

        在本文中,我们将使用 TensorFlow 和 Keras 实现一个简化版的 GoogLeNet 模型来进行 MNIST 数据集的手写数字分类任务。GoogLeNet 采用了 Inception 模块,这使得它在处理图像数据时能更高效地提取特征。本教程将详细介绍如何在 MNIST 数据集上训练和测试这个模型。

项目结构

        我们的代码将分为两个部分:

  1. 训练部分 (train.py): 包含模型定义、数据加载、模型训练等。
  2. 测试部分 (test.py): 用于加载训练好的模型,并在测试集上评估其性能。

训练部分:train.py

1. 数据加载与预处理

        首先,我们需要加载 MNIST 数据集并进行预处理。预处理包括调整图像形状、归一化以及 One-Hot 编码标签。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categoricaldef load_and_preprocess_data():# 加载 MNIST 数据集(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理:将图像形状调整为 [28, 28, 1],并归一化到 [0, 1] 范围train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)) / 255.0test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0# One-Hot 编码标签train_labels = to_categorical(train_labels, 10)test_labels = to_categorical(test_labels, 10)return train_images, train_labels, test_images, test_labels

2. 创建简化版 GoogLeNet 模型

        接下来,我们定义一个简化版的 GoogLeNet 模型。该模型包括卷积层、Inception 模块和全连接层。

from tensorflow.keras import layers, modelsdef googlenet(input_shape=(28, 28, 1), num_classes=10):inputs = layers.Input(shape=input_shape)# 第一卷积层 + 池化层x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)x = layers.MaxPooling2D((2, 2))(x)# 第二卷积层 + 池化层x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)x = layers.MaxPooling2D((2, 2))(x)# 第三卷积层 + 池化层x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)x = layers.MaxPooling2D((2, 2))(x)# Inception 模块inception1 = layers.Conv2D(64, (1, 1), activation='relu', padding='same')(x)inception2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)inception3 = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(x)# 拼接 Inception 模块的输出x = layers.concatenate([inception1, inception2, inception3], axis=-1)# 全局平均池化层x = layers.GlobalAveragePooling2D()(x)# 全连接层x = layers.Dense(1024, activation='relu')(x)x = layers.Dropout(0.5)(x)  # Dropout 层减少过拟合outputs = layers.Dense(num_classes, activation='softmax')(x)  # 输出层,使用 softmax 激活函数进行多分类model = models.Model(inputs=inputs, outputs=outputs)return model

3. 模型训练

        定义好模型之后,我们使用 Adam 优化器和交叉熵损失函数来训练模型,并保存训练好的模型。

def train_model(model, train_images, train_labels, epochs=5, batch_size=64):# 训练模型history = model.fit(train_images, train_labels,epochs=epochs,batch_size=batch_size)return historydef save_model(model, filename='googlenet_mnist.h5'):model.save(filename)print(f"Model saved to {filename}")

4. 主程序

        最后,在主程序中,我们加载数据、创建模型并开始训练。

def main():train_images, train_labels, test_images, test_labels = load_and_preprocess_data()model = googlenet(input_shape=(28, 28, 1), num_classes=10)model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])train_model(model, train_images, train_labels, epochs=5, batch_size=64)save_model(model)if __name__ == '__main__':main()


测试部分:test.py

1. 加载训练好的模型

        在测试部分,我们将加载训练好的模型,并在测试集上进行评估。

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categoricaldef load_and_preprocess_data():(_, _), (test_images, test_labels) = mnist.load_data()test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)) / 255.0test_labels = to_categorical(test_labels, 10)return test_images, test_labelsdef load_model(model_path='googlenet_mnist.h5'):model = tf.keras.models.load_model(model_path)return model

2. 评估模型

        我们通过 evaluate 方法评估模型的损失和准确率。

def evaluate_model(model, test_images, test_labels):test_loss, test_acc = model.evaluate(test_images, test_labels)print(f"Test accuracy: {test_acc * 100:.2f}%")return test_loss, test_acc

3. 显示预测结果

        使用 Matplotlib 可视化前几张图片的预测结果。

import matplotlib.pyplot as pltdef display_predictions(model, test_images, test_labels, num_images=6):predictions = model.predict(test_images[:num_images])fig, axes = plt.subplots(2, 3, figsize=(10, 6))axes = axes.flatten()for i in range(num_images):ax = axes[i]ax.imshow(test_images[i].reshape(28, 28), cmap='gray')ax.set_title(f"Pred: {tf.argmax(predictions[i]).numpy()} \n True: {tf.argmax(test_labels[i]).numpy()}")ax.axis('off')plt.tight_layout()plt.show()

4. 主程序

        在主程序中,我们加载模型,评估其性能,并显示预测结果。

def main():test_images, test_labels = load_and_preprocess_data()model = load_model('googlenet_mnist.h5')evaluate_model(model, test_images, test_labels)display_predictions(model, test_images, test_labels)if __name__ == '__main__':main()


总结

        本文介绍了如何使用 TensorFlow 实现简化版 GoogLeNet,并在 MNIST 数据集上进行训练和测试。我们将代码分为训练和测试两部分,分别处理数据预处理、模型训练与评估、结果展示等工作。

        通过使用 GoogLeNet 进行图像分类,我们不仅能够提高分类性能,还能了解 Inception 模块在图像处理中的强大能力。希望这篇博客能够帮助你更好地理解深度学习模型的训练与测试过程。

完整项目:GoogLeNet-TensorFlow: 使用TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/goog-le-net-tensor-flow

qxd-ljy/GoogLeNet-TensorFlow: 使用 TensorFlow实现简化版 GoogLeNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-TensorFlow 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/19039.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

TON商城与Telegram App:生态融合与去中心化未来的精彩碰撞

随着区块链技术的快速发展,去中心化应用(DApp)逐渐成为了数字生态的重要组成部分。而Telegram作为全球领先的即时通讯应用,不仅仅满足于传统的社交功能,更在区块链领域大胆探索,推出了基于其去中心化网络的…

vulhub之log4j

Apache Log4j Server 反序列化命令执行漏洞(CVE-2017-5645) 漏洞简介 Apache Log4j是一个用于Java的日志记录库,其支持启动远程日志服务器。Apache Log4j 2.8.2之前的2.x版本中存在安全漏洞。攻击者可利用该漏洞执行任意代码。 Apache Log4j 在应用程序中添加日志记录最…

web服务nginx实验4:访问控制

4-1:基于不同用户的访问控制: 安装软件: 创建HTTP基本认证用户密码文件,tom,密码:1,lisa,密码:1: -c:表示创建一个新的密码文件。如果该文件已经…

基于FastAPI实现本地大模型API封装调用

关于FastAPI FastAPI 是一个现代、快速(高性能)的 Python Web 框架,用于构建基于标准 Python 类型提示的 API。它以简洁、直观和高效的方式提供工具,特别适合开发现代 web 服务和后端应用程序。 问题:_pad() got an un…

数字化点亮库布其沙漠的绿色梦想

Bentley 应用程序助力提升设计和施工效率,提前六周交付设计成果 清洁能源为沙漠带来新活力 库布其光伏治沙项目(以下简称“该项目”)位于内蒙古鄂尔多斯市库布其沙漠,占地约 10 万亩,是中国单体规模最大的光伏治沙项目…

基于单片机的风能太阳能供电的路灯智能控制系统设计(论文+源码)

1系统总体设计 本课题为风能太阳能供电的路灯智能控制系统设计,系统的主要功能设计如下: (1) 供电模块:采用太阳能板以及风机模拟风扇充电,经过充电电路给锂电池进行充电。再由锂电池给照明模块以及整个项…

Linux Centos7 Rocky网卡配置

目录 1.Vmare 虚拟机配置 (1)打开虚拟机输入ip a,查看ip网段,若为192.168.81.135 (2)在Vmare上的虚拟网络配置器配置 (3)确保电脑有VMnet1 VMnet8 2.Linux虚拟机Centos配置 &#…

MySQL索引原理之查询优化

MySQL索引原理之查询优化 1、慢查询定位 开启慢查询日志 查看 MySQL 数据库是否开启了慢查询日志和慢查询日志文件的存储位置的命令如下: SHOW VARIABLES LIKE %slow_query_log%通过如下命令开启慢查询日志: SET global slow_query_log 1; SET global …

ArchGuard 架构分析器发布:多语言、跨项目架构数据生成,助力 AI 时代知识挖掘...

TL;DR:https://github.com/archguard/archguard 过去的几个月里,我们一直在探索用 AI 辅助跨项目、跨大量微服务的系统的开发。其中一个重要的话题就是,从现有的软件架构去生成知识,文档是落后、多版本的, 只有代码才保…

NLP论文速读(多伦多大学)|利用人类偏好校准来调整机器翻译的元指标

论文速读|MetaMetrics-MT: Tuning Meta-Metrics for Machine Translation via Human Preference Calibration 论文信息: 简介: 本文的背景是机器翻译(MT)任务的评估。在机器翻译领域,由于不同场景和语言对的需求差异&a…

工程车识别算法平台LiteAIServer算法定制工程车类型检测算法:建筑工地安全管理的得力助手

随着科技的飞速发展,智能化技术正在逐步改变我们的生活方式,特别是在交通管理和安全管理领域。其中,算法定制LiteAIServer工程车类型检测算法以其高效、准确和实时的特性,成为了建筑工地管理、矿山开采以及物流运输等多个领域的重…

机器学习2

三、特征工程 接机器学习1 4、特征降维 4.2、主成分分析PCA 从原始特征空间中找到一个新的坐标系统,使得数据在新坐标轴上的投影能够最大程度地保留数据的方差,同时减少数据的维度。 保留信息/丢失信息信息保留的比例 from sklearn.decomposition imp…

【Linux之权限】提升篇

前言 在前两篇文章里,我们已经学习了Linux中权限的理论、实践和重点,接下来我们将进一步提升对Linux权限的全面认知。虽是拓展,其实还是重点。 本文内容并不多,那我们就开始吧。 目录的权限该如何理解呢? 如果我想进…

亮数据结合AI大模型,实现数据自由

目录 一、获取网络数据的挑战1、反爬虫机制的威胁2、IP封锁与访问频率控制3、数据隐私与法律合规 二、亮数据动态代理:数据采集的最佳拍档1、高质量IP资源2、智能调度与自动切换3、合规与隐私保护4、多场景应用支持 三、使用亮数据代理 IP进行网络数据抓取1、引入 r…

elasticsearch是如何实现master选举的?

大家好,我是锋哥。今天分享关于【elasticsearch是如何实现master选举的?】面试题。希望对大家有帮助; elasticsearch是如何实现master选举的? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在 Elasticsearch 中&…

EtherNet/IP转Profinet网关连接发那科机器人配置实例解析

本案例主要展示了如何通过Ethernet/IP转Profinet网关实现西门子1200PLC与发那科搬运机器人的连接。所需的设备有西门子1200PLC、开疆智能Ethernet/IP转Profinet网关以及Fanuc机器人。 具体配置步骤:打开西门子博图配置软件,添加PLC。这是配置的第一步&am…

Uniapp运行环境判断和解决跨端兼容性详解

Uniapp运行环境判断和解决跨端兼容性 开发环境和生产环境 uniapp可通过process.env.NODE_ENV判断当前环境是开发环境还是生产环境,一般用于链接测试服务器或者生产服务器的动态切换。在HX中,点击运行编译出来的代码是开发环境,点击发行编译…

C语言 for 循环:解谜数学,玩转生活!

放在最前面的 🎈 🎈 我的CSDN主页:OTWOL的主页,欢迎!!!👋🏼👋🏼 🎉🎉我的C语言初阶合集:C语言初阶合集,希望能…

【专题】2024AIGC创新应用洞察报告汇总PDF洞察(附原数据表)

原文链接:https://tecdat.cn/?p38310 在科技日新月异的今天,人工智能领域正以前所未有的速度发展,AIGC(人工智能生成内容)成为其中最耀眼的明珠。从其应用场景的不断拓展,到对各行业的深刻变革&#xff0…

.NET桌面应用架构Demo与实战|WPF+MVVM+EFCore+IOC+DI+Code First+AutoMapper

目录 .NET桌面应用架构Demo与实战|WPFMVVMEFCoreIOCDICode FirstAutoPapper技术栈简述项目地址:功能展示项目结构项目引用1. 新建模型2. Data层,依赖EF Core,实现数据库增删改查3. Bussiness层,实现具体的业务逻辑4. Service层&am…