TensorFlow_T7 咖啡豆识别

目录

一、前言

二、前期准备

1、设置GPU

2、导入数据

3、查看数据图片

三、数据预处理 

1、加载数据

2、可视化数据

3、配置数据集

四、构建VGG-16网络

1、VGG优缺点分析

2、自建模型

3、网络结构图

五、编译

六、 训练模型


一、前言

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

二、前期准备

1、设置GPU

import tensorflow as tfgpus=tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0],True)tf.config.set_visible_devices([gpus[0]],"GPU")
gpus

2、导入数据

import pathlibdata_dir="D:\THE MNIST DATABASE\T7"
data_dir=pathlib.Path(data_dir)

3、查看数据图片

image_count=len(list(data_dir.glob('*/*.png')))print("图片总数为:",image_count)

运行结果:

 

三、数据预处理 

1、加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中。

train_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(224,224),batch_size=32
)

运行结果如下:

加载验证集:

val_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(224,224),batch_size=32
)

运行结果如下:

通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names=train_ds.class_names
print(class_names)

 运行结果如下:

2、可视化数据

import matplotlib.pyplot as pltplt.figure(figsize=(10,4))for images,labels in train_ds.take(1):for i in range(10):ax=plt.subplot(2,5,i+1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

运行结果如下:

查看图像格式:

for image_batch,labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

运行结果如下:

3、配置数据集

  • shuffle() :打乱数据;
  • prefetch() :预取数据,加速运行;
  • cache() :将数据集缓存到内存当中,加速运行;

对数据集进行预处理,对于验证集只进行了缓存和预取操作,没有进行打乱操作。

AUTOTUNE=tf.data.AUTOTUNEtrain_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

对图像数据集进行归一化。 将输入数据除以255,将像素值缩放到0到1之间。然后,使用map函数将这个归一化层应用到训练数据集train_ds和验证数据集val_ds的每个样本上。这样,所有的图像都会被归一化,以便在神经网络中更好地处理。

from tensorflow.keras import layersnormalization_layer=layers.experimental.preprocessing.Rescaling(1./255)train_ds=train_ds.map(lambda x,y:(normalization_layer(x),y))
val_ds=val_ds.map(lambda x,y:(normalization_layer(x),y))

从验证数据集中获取一个批次的图像和标签,然后将第一个图像存储在变量first_image中。接下来,使用numpy库的min和max函数分别计算first_image中的最小值和最大值,并将它们打印出来。这样可以帮助我们了解图像数据的归一化情况,例如是否所有像素值都在0到1之间。

import numpy as npimage_batch,labels_batch=next(iter(val_ds))
first_image=image_batch[0]#查看归一化后的数据
print(np.min(first_image),np.max(first_image))

运行结果如下:

四、构建VGG-16网络

1、VGG优缺点分析

(1)优点:结构简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2);

(2)缺点

  • 训练时间过长,调参难度大;
  • 需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中;

2、自建模型

from tensorflow.keras import layers,models,Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Dense,Flatten,Dropoutdef vgg16(nb_classes,input_shape):input_tensor=Input(shape=input_shape)#1st blockx=Conv2D(64,(3,3),activation='relu',padding='same')(input_tensor)x=Conv2D(64,(3,3),activation='relu',padding='same')(x)x=MaxPooling2D((2,2),strides=(2,2))(x)#2nd blockx=Conv2D(128,(3,3),activation='relu',padding='same')(x)x=Conv2D(128,(3,3),activation='relu',padding='same')(x)x=MaxPooling2D((2,2),strides=(2,2))(x)#3rd blockx=Conv2D(256,(3,3),activation='relu',padding='same')(x)x=Conv2D(256,(3,3),activation='relu',padding='same')(x)x=Conv2D(256,(3,3),activation='relu',padding='same')(x)x=MaxPooling2D((2,2),strides=(2,2))(x)#4th blockx=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=MaxPooling2D((2,2),strides=(2,2))(x)#5th blockx=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=Conv2D(512,(3,3),activation='relu',padding='same')(x)x=MaxPooling2D((2,2),strides=(2,2))(x)#full connectionx=Flatten()(x)x=Dense(4096,activation='relu')(x)x=Dense(4096,activation='relu')(x)output_tensor=Dense(nb_classes,activation='softmax',name='predictions')(x)model=Model(input_tensor,output_tensor)return modelmodel=vgg16(len(class_names),(224,224,3))
model.summary()

运行结果如下:

   

   

3、网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示;
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示;
  • 5个池化层(Pool layer),分别用blockX_pool表示;

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16;

五、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率;
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新;
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率;
#设置初始学习率
initial_learning_rate=1e-4lr_schedule=tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=30,decay_rate=0.92,staircase=True
)#设置优化器
opt=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

六、 训练模型

epochs=20history=model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

运行结果如图:

七、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

运行结果如图:


👏觉得文章对自己有用的宝子可以收藏文章并给小编点个赞!

👏想了解更多统计学、数据分析、数据开发、机器学习算法、深度学习等有关知识的宝子们,可以关注小编,希望以后我们一起成长!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/15501.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

平替备用机!不到 5 元的 410 随身 WiFi 免刷机实现短信转发

本文首发于只抄博客,欢迎点击原文链接了解更多内容。 前言 各位用备用机的,应该有很多只是插上不常用的手机卡,然后装个短信转发的 App 来看看验证码什么的吧。每隔一段时间还要去看看备用机还有没有电,但其实这种需求骁龙 410 芯…

【EmbeddedGUI】脏矩阵设计说明

脏矩阵设计说明 背景介绍 一般情况下,当屏幕内容绘制完毕后,实际应用通常需要更新屏幕中的一部分内容,而不是单纯显示一个静态图片在那。 如下图所示,屏幕中有一个图片控件(Img2)和一个文本控件&#xf…

网络基础-超文本协议与内外网划分(超长版)

一、超文本协议 1. HTTP协议简介 1.1. 网络架构简单介绍 (1). C/S架构(Client/Server架构) (2). B/S架构(Browser/Server) 总结对比 2. HTTP协议版本 2.1. HTTP/0.9 (1991年发布) 2.2. HTTP/1.0 &a…

5分钟搞懂 Golang 堆内存

本文主要解释了堆内存的概念,介绍了 Linux 堆内存的工作原理,以及 Golang 如何管理堆内存。原文: Understanding Heap Memory in Linux with Go 你想过为什么堆内存被称为 "堆" 吗?想象一下杂乱堆放的对象,与此类似&…

今日 AI 简报 | 开源 RAG 文本分块库、AI代理自动化软件开发框架、多模态统一生成框架、在线图像背景移除等

❤️ 如果你也关注大模型与 AI 的发展现状,且对大模型应用开发非常感兴趣,我会快速跟你分享最新的感兴趣的 AI 应用和热点信息,也会不定期分享自己的想法和开源实例,欢迎关注我哦! 🥦 微信公众号&#xff…

【C++学习(35)】在Linux中基于ucontext实现C++实现协程(Coroutine),基于C++20的co_await 协程的关键字实现协程

文章目录 为什么使用协程协程的理解协程优势协程的原语操作yield 与 resume 是一个switch操作(三种实现方式): 基于 ucontext 的协程基于 XFiber 库的操作1 包装上下文2 XFiber 上下文调度器2.1 CreateFiber2.2 Dispatch 基于C20的co_return …

技术段子——论如何在0.387秒以内获取到闲鱼的上新数据。

个人一直在做闲鱼辅助相关的工具类软件。因为知道阿里系请求和风控的原因,再加个人做软件一直想的是如何让用户稳定运行。 因为阿里系对于请求的风控,所以个人风格导到软件效率一直一般。并不是做不到快速抓取,而是用效率换稳定。 所以&#…

【C#设计模式(10)——装饰器模式(Decorator Pattern)】

前言 装饰器模式可以在运行时为对象添加额外的功,而无需修改原始对象的代码。这种方式比继承更加灵活。 代码 //蛋糕类(抽象类) public abstract class Cake {public abstract void Create(); } //奶油蛋糕类 public class CreamCake : Cak…

2025年PMP考试安排是怎样?备考计划与重要时间节点公布

PMP考试在中国大陆每年举行四次,分别是在3月、6月、9月和12月。而中国港澳台地区的PMP考试则可以每天进行机考。在中国大陆地区的笔试考试中,主要采用涂卡和机读卡来记录成绩。 每次PMP考试的时间都是在周六的9点到12点50分,共计230分钟。 P…

缓冲式线程池C++简易实现

前言 : 代码也比较短&#xff0c;简单说一下代码结构&#xff0c;是这样的&#xff1a; SyncQueue.hpp封装了一个大小为MaxTaskCount的同步队列&#xff0c;这是一个模板类&#xff0c;它在线程池中承担了存放任务等待线程组中的线程来执行的角色。最底层是std::list<T>…

推荐一款功能强大的光学识别OCR软件:Readiris Dyslexic

Readiris Dyslexic是一款功能强大的光学识别OCR软件&#xff0c;可以扫描任何纸质文档并将其转换为完全可编辑的数字文件(Word&#xff0c;Excel&#xff0c;PDF)&#xff0c;然后用你喜欢的编辑器进行编辑。该软件提供了一种轻松创建&#xff0c;修改和签名PDF的完整解决方法&…

【面试全纪实 | Nginx 04】请回答,你真的精通Nginx吗?

&#x1f5fa;️博客地图 &#x1f4cd;1、location的作用是什么&#xff1f; &#x1f4cd;2、你知道漏桶流算法和令牌桶算法吗&#xff1f; &#x1f4cd;3、Nginx限流怎么做的&#xff1f; &#x1f4cd;4、为什么要做动静分离&#xff1f; &#x1f4cd;5、Nginx怎么做…

如何为你的 SaaS 公司做好国际化发展的准备?

随着 SaaS&#xff08;软件即服务&#xff09;公司的不断发展&#xff0c;确定扩张机会并建立可扩展的流程和策略以支持这些机会变得至关重要。一些公司向上游市场扩张&#xff0c;向企业销售产品&#xff0c;而此前他们主要面向中小企业。一些公司则朝着相反的方向发展&#x…

Towards Reasoning in Large Language Models: A Survey

文章目录 题目摘要引言什么是推理?走向大型语言模型中的推理测量大型语言模型中的推理发现与启示反思、讨论和未来方向 为什么要推理?结论题目 大型语言模型中的推理:一项调查 论文地址:https://arxiv.org/abs/2212.10403 项目地址: https://github.com/jeffhj/LM-reason…

推荐一款硬盘数据清除工具:Macrorit Data Wiper

Macrorit Data Wiper是一款硬盘数据清除工具&#xff0c;用于安全擦除数据、分区和磁盘的一站式工具包。完全擦除系统/引导分区。许多程序文件默认存储在系统磁盘驱动器中。如果您或您的组织想要永久擦除磁盘驱动器以防止未经授权使用您的数据&#xff0c;则此功能是必要的。 为…

第13章 Zabbix分布式监控企业实战

企业服务器对用户提供服务,作为运维工程师最重要的事情就是保证该网站正常稳定的运行,需要实时监控网站、服务器的运行状态,并且有故障及时去处理。 监控网站无需人工时刻去访问WEB网站或者登陆服务器去检查,可以借助开源监控软件例如Zabbix、Cacti、Nagios、Ganglia等来实…

2024IJCAI | MetalISP: 仅用1M参数的RAW到RGB高效映射模型

文章标题是&#xff1a;《MetaISP:Effcient RAW-to-sRGB Mappings with Merely 1M Parameters》 MetaISP收录于2024IJCAI&#xff0c;是新加坡国立大学&#xff08;Xinchao Wang为通讯作者&#xff09;和华为联合研发的新型ai-isp。 原文链接&#xff1a;MetaISP 【1】论文的…

使用 ts-node 运行 ts文件,启动 nodejs项目

最近在写一个nodejs项目&#xff0c;使用 ts-node 启动项目。遇到了一些问题&#xff0c;在此记录一下。 ts-node 是 TypeScript 执行引擎和 Node.js 的 REPL(一个简单的交互式的编程环境)。 它能够直接在 Node.js 上执行 TypeScript&#xff0c;而无需预编译。 这是通过挂接…

《鸿蒙生态:开发者的机遇与挑战》

一、引言 在当今科技飞速发展的时代&#xff0c;操作系统作为连接硬件与软件的核心枢纽&#xff0c;其重要性不言而喻。鸿蒙系统的出现&#xff0c;为开发者带来了新的机遇与挑战。本文将从开发者的角度出发&#xff0c;阐述对鸿蒙生态的认知和了解&#xff0c;分析鸿蒙生态的…

PHP代码审计 - SQL注入

SQL注入 正则搜索(update|select|insert|delete).*?where.*示例一&#xff1a; bluecms源码下载&#xff1a;source-trace/bluecms 以项目打开网站根目录&#xff0c;并以ctrlshiftf打开全局搜索 (update|select|insert|delete).*?where.*并开启正则匹配 最快寻找脆弱点的…