第T2周:TensorFlow实现彩色图片分类(CIFAR10数据集),并实现自己的真实图片分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标
加载CIFAR-10数据集进行训练,然后能够对彩色图片进行分类
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
*框架:*TensorFlow
**(二)具体步骤:
1. 设置使用GPU

# 设置使用GPU  
gpus = tf.config.list_physical_devices("GPU")  
# print(gpus)  
if gpus:  gpu0 = gpus[0]  tf.config.experimental.set_memory_growth(gpu0, True)  tf.config.set_visible_devices([gpu0], "GPU")

2.导入数据集

# 导入数据集  
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

image.png
image.png
3. 数据标准化

# 数据标准化到0-1区间内  
train_images, test_images = train_images / 255.0, test_images / 255.0  
print(train_images, test_images)

image.png
4.可视化数据

# 可视化数据  
class_names = ['飞机', '小汽车', '鸟', '猫', '鹿',  '狗', '青蛙', '马', '船', '卡车']
plt.figure(figsize=(20, 10))  
for i in range(20):  plt.subplot(5, 10, i+1)  plt.xticks([])  plt.yticks([])  plt.grid(False)  plt.imshow(train_images[i], cmap=plt.cm.binary)  plt.xlabel(class_names[train_labels[i][0]])  plt.show()

5.构建CNN网络模型
image.png

# 构建CNN网络  
model = models.Sequential([  layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),  layers.MaxPooling2D((2, 2)),  layers.Conv2D(64, (3, 3), activation='relu'),  layers.MaxPooling2D((2, 2)),  layers.Conv2D(64, (3, 3), activation='relu'),  layers.Flatten(),  layers.Dense(64, activation='relu'),  layers.Dense(10)  
])  print(model.summary())
Model: "sequential"
┌─────────────────────────────────┬────────────────────────┬───────────────┐
│ Layer (type)                    │ Output Shape           │       Param # │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 30, 30, 32)     │           896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 15, 15, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 13, 13, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 6, 6, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 4, 4, 64)       │        36,928 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 1024)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 64)             │        65,600 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘Total params: 122,570 (478.79 KB)Trainable params: 122,570 (478.79 KB)Non-trainable params: 0 (0.00 B)
None

6.编译与训练模型

# 编译模型  
model.compile(optimizer='adam',  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  metrics=['accuracy'])  
# 训练模型  
history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.3335 - loss: 1.7990 - val_accuracy: 0.5389 - val_loss: 1.2733
Epoch 2/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.5518 - loss: 1.2573 - val_accuracy: 0.5991 - val_loss: 1.1310
Epoch 3/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.6235 - loss: 1.0623 - val_accuracy: 0.6547 - val_loss: 0.9888
Epoch 4/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.6627 - loss: 0.9574 - val_accuracy: 0.6547 - val_loss: 0.9930
Epoch 5/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.6929 - loss: 0.8715 - val_accuracy: 0.6660 - val_loss: 0.9542
Epoch 6/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7174 - loss: 0.8132 - val_accuracy: 0.6943 - val_loss: 0.8771
Epoch 7/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7368 - loss: 0.7568 - val_accuracy: 0.6978 - val_loss: 0.8687
Epoch 8/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7495 - loss: 0.7141 - val_accuracy: 0.6963 - val_loss: 0.8821
Epoch 9/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7682 - loss: 0.6607 - val_accuracy: 0.6795 - val_loss: 0.9167
Epoch 10/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.7755 - loss: 0.6344 - val_accuracy: 0.7016 - val_loss: 0.8806

7.预测

# 看一下要预测的图片是什么  
plt.imshow(test_images[1])  
plt.show()

image.png
可以看出是一个船。看看模型能否预测准确:

import numpy as np  
pre = model.predict(test_images)  
print(class_names[np.argmax(pre[1])])

image.png
预测准确。
8.预测一下我们自己的图片
工程上新创建一个目录data,网上找一张鹿的图片保存在data中:
image.png

# 预测一下真实照片  
image_path = "data/cat2.jpg"  # 图片存储路径
original_image = tf.io.read_file(image_path, 'r')  
# print(original_image)   # 原始图片数据  # 将原始图片数据转换成tensor格式  
original_image_tensor = tf.io.decode_jpeg(original_image)  
# print(original_image_tensor)    # 打印图片tensor数据  
# print(original_image_tensor.shape)  # 图片形状(750, 500, 3)  
# 根据上面的输入特征(32, 32, 3),因此需要将图片大小改成(32, 32)的。  
original_image_tensor_resize = tf.image.resize(original_image_tensor, [32, 32])  
# print(original_image_tensor_resize.shape)   # resize后的形状  # reshape成(32, 32, 3)  
original_image_tensor_resize_reshape = tf.reshape(original_image_tensor_resize, [-1, 32, 32, 3])  
# 显示图片  
for i in range(3):  plt.imshow(original_image_tensor_resize_reshape[0, :, :, i])  plt.title(str(i))  plt.colorbar()  plt.show()  # 再进行标准化到 0-1 区间  
original_image_tensor_resize_reshape_normalize = original_image_tensor_resize_reshape / 255.0  
# print(original_image_tensor_resize_reshape_normalize.shape)  # 开始预测
import numpy as np  
pre = model.predict(original_image_tensor_resize_reshape_normalize)  
# print(pre)  
# 打印预测结果
print("当前图片预测为: ", class_names[np.argmax(pre[0])])

image.png
预测正确。

(三)总结

  1. 熟悉各个模型搭建、训练到预测的流程
  2. 了解神经网络模型(黑盒子)的细节
  3. 并不是每次都能预测正确,对于真实图片的预处理,要怎么样提升准确性,后续研究。
  4. 并不是把epochs提高,准确性就提高,继续研究。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1549053.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Python | Leetcode Python题解之第442题数组中重复的数据

题目: 题解: class Solution:def findDuplicates(self, nums: List[int]) -> List[int]:ans []for x in nums:x abs(x)if nums[x - 1] > 0:nums[x - 1] -nums[x - 1]else:ans.append(x)return ans

可以免费制作表情包的AI工具来了!

一直想自己制作一套表情包,但一直没有找到好用的工具,要么就是太麻烦,要么就是不免费。 今天AI表情包免费制作工具来了,手机就可以直接做表情包,非常方便。 先看效果~ 工具用到的是通义APP,可以在频道中找…

一款革命性的AI写作工具——文字游侠AI大模型重大升级,创作效率提高高达20倍,小白也能轻松实现月入过万!

在自媒体创作的浪潮中,如何高效地生产高质量内容成为许多创作者的难题。然而,随着AI技术的飞速发展,这一难题得到了完美的解决。今天,我要为大家介绍一款革命性的AI写作工具——文字游侠AI大模型,它不仅能够大幅提高创…

Rust赋能前端:为WebAssembly 瘦身

❝ 凡事你一旦接纳了,就不存在了;你看不惯它,它就一直折磨你 大家好,我是柒八九。一个专注于前端开发技术/Rust及AI应用知识分享的Coder ❝ 此篇文章所涉及到的技术有 WebAssembly Rust SIMD LLVM binaryen 因为,行文字…

Llama 3.1 技术研究报告-5

5.3 人工评测 除了在标准基准测试集上的评估外,我们还进⾏了⼀系列⼈类评估。这些评估使我能够测量和优化模型性能的更微妙⽅⾯,例如模型的语调、冗⻓性和对细微差别及⽂化背景的理解。精⼼设计的⼈类评估密切反映了⽤⼾体验,提供了模型在现…

PacketSender使用说明

1、Packet Sender介绍 Packet Sender是一个开源实用程序,允许发送和接收TCP、UDP和SSL(加密TCP)数据包,以及HTTP/HTTPS请求和面板生成。主线分支正式支持Windows、Mac和桌面Linux(带Qt)。其他地方可能会重…

隧道灯光远程控制系统的设计与实现(论文+源码)_kaic

摘要 随着互联网的发展,物联网的时代己经到来。无线控制技术的应用已经普及到了我们生活中的各个角落。节能环保的意识也在不断的加强,隧道照明作为隧道建设的一个主要的环节,一个好的隧道照明系统不仅仅能保障隧道车辆的正常通行&#xff0c…

无需科学!Copilot网页版GPT-4无限制对话来了!

之前本公众号讲过: 微软copilot分为免费版copilot、个人家庭版copilot pro(每月20刀)和商业版copilot for Microsoft 365(每月30刀)。 其中免费版和个人家庭版的copilot无论在任何情况下使用都需要科学手段。 商业版…

什么是前缀索引?

什么是前缀索引? 1、什么是前缀索引?2、为什么要使用前缀索引?3、如何选择前缀长度?4、创建前缀索引的SQL语法5、示例 💖The Begin💖点点关注,收藏不迷路💖 在处理包含长字符串的数据…

3款照片人物开口说话AI工具,跟真人说话一样~免费!短视频带货必备!(附教程)

大家好,我是画画的小强 今天给大家分享一个AI图片口播数字人讲认知思维,单号佣金赚5W的AI带货信息差玩法,许多小伙伴表示对这类AI带货玩法感兴趣。 说实话,现在AI照片人物对口型工具,越来越逼真,很难辨识出…

8.使用 VSCode 过程中的英语积累 - Help 菜单(每一次重点积累 5 个单词)

前言 学习可以不局限于传统的书籍和课堂,各种生活的元素也都可以做为我们的学习对象,本文将利用 VSCode 页面上的各种英文元素来做英语的积累,如此做有 3 大利 这些软件在我们工作中是时时刻刻接触的,借此做英语积累再合适不过&a…

牛犇啊!LSTM+Transformer炸裂创新,精准度高至95.65%!

【LSTMTransformer】作为一种混合深度学习模型,近年来在学术界和工业界都受到了极大的关注。它巧妙地融合了长短期记忆网络(LSTM)在处理时序数据方面的专长和Transformer在捕捉长距离依赖关系上的优势,从而在文本生成、机器翻译、…

做中视频计划,哪里找素材?推荐几个热门中视频素材下载网站

在做中视频计划时,寻找合适的素材至关重要。抖音上那些热门的中视频素材都是从哪里下载的呢?以下五大高清素材库值得收藏,赶紧来看看吧! 蛙学网 蛙学网提供了百万级的中视频素材,质量高且是4K高清无水印,视…

Android使用RecyclerView仿美团分类界面

RecyclerView目前来说对大家可能不陌生了。由于在公司的项目中,我们一直用的listview和gridview。某天产品设计仿照美团的分类界面设计了一个界面,我发现用gridview不能实现这样的效果,所以就想到了RecyclerView,确实是一个很好的…

(最新已验证)stm32 + 新版 onenet +dht11+esp8266/01s + mqtt物联网上报温湿度和控制单片机(保姆级教程)

物联网实践教程:微信小程序结合OneNET平台MQTT实现STM32单片机远程智能控制 远程上报和接收数据——汇总 前言 之前在学校获得了一个新玩意:ESP-01sWIFI模块,去搜了一下这个小东西很有玩点,远程控制LED啥的,然后我就想…

[大语言模型-论文精读] Diffusion Model技术-通过时间和空间组合扩散模型生成复杂的3D人物动作

​​​​​​Generation of Complex 3D Human Motion by Temporal and Spatial Composition of Diffusion Models L Mandelli, S Berretti - arXiv preprint arXiv:2409.11920, 2024 通过时间和空间组合扩散模型生成复杂的3D人物动作 摘要 本文提出了一种新的方法&#xff0…

UCS512DHN DMX512差分并联协议LED驱动IC 舞动灯光的魔法芯片

UCS512DHN产品概述: UCS512DHN是DMX512差分并联协议LED驱动芯片,可选择1/2/3/4通道高精度恒流输出,灰度达65536 级。UCS512DHN为带散热片封装的大电流输出版本。UCS512DHN有PWM反极性输出功能,此功能适合外挂三极 管,…

极品飞车14热力追踪原始版高清重制版MOD分享

《极品飞车14:热力追击》(Need for Speed:Hot Pursuit)是由Criterion Games工作室负责开发,EA公司2010年底发行的一款竞速类游戏,也是新一代的热力追踪系列作品,游戏平台为Xbox 360、PS3。 《极品飞车14&a…

11. LCEL:LangChain Expression Language

这篇文章覆盖了LCEL的理解和他是如何工作的。 LCEL(LangChain Expression Language):是把一些有趣python概念抽象成一种格式,从而为构建LangChain组件链提供一种“简约”代码层。 LCEL在下面方面有着强大的支撑: 链的快速开发流式输出、异…

线性方程组的迭代方法

目录 直接方法与迭代方法 常规迭代算法 选择迭代求解器 预条件子 预条件子示例 均衡和重新排序 使用线性运算函数取代矩阵 数值线性代数最重要也是最常见的应用之一是可求解以 A*x b 形式表示的线性方程组。当 A 为大型稀疏矩阵时,您可以使用迭代方法求解线…