tensorflow-卷积神经网络-图像分类入门demo

猫狗识别

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

导入工具包

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

指定好数据路径(训练和验证)

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

构建卷积神经网络模型

  • 几层都可以,大家可以随意玩
  • 如果用CPU训练,可以把输入设置的更小一些,一般输入大小更主要的决定了训练速度
  • model = tf.keras.models.Sequential([#如果训练慢,可以把数据设置的更小一些tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(128, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),#为全连接层准备tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),# 二分类sigmoid就够了tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.summary()

  • 配置训练器

    model.compile(loss='binary_crossentropy',optimizer=Adam(lr=1e-4),metrics=['acc'])

    数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
    train_datagen = ImageDataGenerator(rescale=1./255)
    test_datagen = ImageDataGenerator(rescale=1./255)
    train_generator = train_datagen.flow_from_directory(train_dir,  # 文件夹路径target_size=(64, 64),  # 指定resize成的大小batch_size=20,# 如果one-hot就是categorical,二分类用binary就可以class_mode='binary')validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(64, 64),batch_size=20,class_mode='binary')

    训练网络模型

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
    history = model.fit_generator(train_generator,steps_per_epoch=100,  # 2000 images = batch_size * stepsepochs=20,validation_data=validation_generator,validation_steps=50,  # 1000 images = batch_size * stepsverbose=2)
    Epoch 1/20
    100/100 - 7s - loss: 0.6892 - acc: 0.5325 - val_loss: 0.6705 - val_acc: 0.5970
    Epoch 2/20
    100/100 - 6s - loss: 0.6595 - acc: 0.6055 - val_loss: 0.6346 - val_acc: 0.6470
    Epoch 3/20
    100/100 - 6s - loss: 0.6350 - acc: 0.6515 - val_loss: 0.6358 - val_acc: 0.6320
    Epoch 4/20
    100/100 - 7s - loss: 0.5936 - acc: 0.6865 - val_loss: 0.5906 - val_acc: 0.6780
    Epoch 5/20
    100/100 - 7s - loss: 0.5530 - acc: 0.7170 - val_loss: 0.5978 - val_acc: 0.6670
    Epoch 6/20
    100/100 - 8s - loss: 0.5179 - acc: 0.7490 - val_loss: 0.5484 - val_acc: 0.7140
    Epoch 7/20
    100/100 - 8s - loss: 0.4854 - acc: 0.7725 - val_loss: 0.5686 - val_acc: 0.7080
    Epoch 8/20
    100/100 - 8s - loss: 0.4595 - acc: 0.7905 - val_loss: 0.5452 - val_acc: 0.7150
    Epoch 9/20
    100/100 - 8s - loss: 0.4406 - acc: 0.7885 - val_loss: 0.5453 - val_acc: 0.7210
    Epoch 10/20
    100/100 - 7s - loss: 0.4109 - acc: 0.8170 - val_loss: 0.5317 - val_acc: 0.7270
    Epoch 11/20
    100/100 - 8s - loss: 0.3892 - acc: 0.8285 - val_loss: 0.5384 - val_acc: 0.7220
    Epoch 12/20
    100/100 - 8s - loss: 0.3542 - acc: 0.8570 - val_loss: 0.5480 - val_acc: 0.7180
    Epoch 13/20
    100/100 - 8s - loss: 0.3421 - acc: 0.8580 - val_loss: 0.5355 - val_acc: 0.7420
    Epoch 14/20
    100/100 - 8s - loss: 0.3217 - acc: 0.8665 - val_loss: 0.5572 - val_acc: 0.7340
    Epoch 15/20
    100/100 - 8s - loss: 0.2931 - acc: 0.8805 - val_loss: 0.5545 - val_acc: 0.7400
    Epoch 16/20
    100/100 - 8s - loss: 0.2739 - acc: 0.8870 - val_loss: 0.5540 - val_acc: 0.7360
    Epoch 17/20
    100/100 - 8s - loss: 0.2535 - acc: 0.9040 - val_loss: 0.5564 - val_acc: 0.7380
    Epoch 18/20
    100/100 - 8s - loss: 0.2257 - acc: 0.9245 - val_loss: 0.5710 - val_acc: 0.7420
    Epoch 19/20
    100/100 - 8s - loss: 0.2084 - acc: 0.9350 - val_loss: 0.5734 - val_acc: 0.7460
    Epoch 20/20
    100/100 - 8s - loss: 0.2258 - acc: 0.9130 - val_loss: 0.5897 - val_acc: 0.7300
    

    效果展示

    import matplotlib.pyplot as plt
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'bo', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss, 'bo', label='Training Loss')
    plt.plot(epochs, val_loss, 'b', label='Validation Loss')
    plt.title('Training and validation loss')
    plt.legend()plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/141969.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

DAZ To UMA⭐三.导入Blender的配置, 及Blender快捷键

文章目录 🟥 Blender快捷键1️⃣ 3D视图快捷键2️⃣ 视角快捷键3️⃣ 编辑快捷键4️⃣ 对物体的操作🟧 Blender导入FBX的配置🟩 设置脸部骨骼大小1️⃣ 切换视角2️⃣ 缩小脸部骨骼3️⃣ 本节效果预览🟦 设置眼角膜透明度🟥 Blender快捷键 1️⃣ 3D视图快捷键 快捷键…

【C语言】进阶——结构体+枚举+联合

①前言: 在之前【C语言】初阶——结构体 ,简单介绍了结构体。而C语言中结构体的内容还有更深层次的内容。 一.结构体 结构体(struct)是由一系列具有相同类型或不同类型的数据项构成的数据集合,这些数据项称为结构体的成员。 1.结构体的声明 …

CSS实现鼠标悬停图片上升显示

文章目录 前言一、实现效果二、实现思路 前言 当我们想在图片上面放置一些文字内容时,发现不管怎么放置,要么就是图片影响到文字的观感,要么就是文字挡住图片的细节,那么怎么可以既看到图片的细节又可以看到对图片的文字描述呢&a…

mysqld_exporter监控MySQL服务

一、MySQL授权 1、登录MySQL服务器对监控使用的账号授权 CREATE USER exporterlocalhost IDENTIFIED BY 123456 WITH MAX_USER_CONNECTIONS 3; GRANT PROCESS, REPLICATION CLIENT, SELECT ON *.* TO exporterlocalhost; flush privileges;2、上传mysqld_exporter安装包&#…

springboot整合MeiliSearch轻量级搜索引擎

一、Meilisearch与Easy Search点击进入官网了解&#xff0c;本文主要从小微型公司业务出发&#xff0c;选择meilisearch来作为项目的全文搜索引擎&#xff0c;还可以当成来mongodb来使用。 二、starter封装 1、项目结构展示 2、引入依赖包 <dependencies><dependenc…

基于图像形态学处理的路面裂缝检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ...................................................... %1&#xff1a;从文件夹中读取多个…

spring的ThreadPoolTaskExecutor装饰器传递调用线程信息给线程池中的线程

概述 需求是想在线程池执行任务的时候&#xff0c;在开始前将调用线程的信息传到子线程中&#xff0c;在子线程完成后&#xff0c;再清除传入的数据。 下面使用了spring的ThreadPoolTaskExecutor来实现这个需求. ThreadPoolTaskExecutor 在jdk中使用的是ThreadPoolExecutor…

asp.net企业生产管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 企业生产管理系统 是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语 言开发 二、功能介绍 (1)用户管理&…

C# OpenCvSharp 基于直线检测的文本图像倾斜校正

效果 项目 代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms; using OpenCvSharp;namespace OpenCvSharp_基于直线检测的文…

Linux基础指令(五)

目录 前言1. 打包和压缩1.1 是什么1.2 为什么1.3 怎么办&#xff1f; 2. zip & unzip3. tar 指令结语&#xff1a; 前言 欢迎各位伙伴来到学习 Linux 指令的 第五天&#xff01;&#xff01;&#xff01; 在上一篇文章 Linux基本指令(四) 当中&#xff0c;我们学习了 fin…

C语言入门Day_25 函数与指针小结

目录 前言&#xff1a; 1.函数 2.指针 3.易错点 4.思维导图 前言&#xff1a; 函数就像一个“有魔法的加工盒”&#xff0c;你从入口丢一些原材料进去&#xff0c;它就能加工出一个成品。不同的函数能加工出不同的成品。 入口丢进去的瓶子&#xff0c;水和标签就是输入&a…

【服务端 | Redis】如何使用redis 有序集合实现股票交易的订单表(价格优先、时间优先)

前两天倒腾redis的有序集合时&#xff0c;自己发现了一个问题&#xff0c;redis的有序集合在score相同的情况 下是如何排序的&#xff1f; 通过谷歌搜索&#xff0c;发现了一些线索&#xff0c;在score相同的情况下&#xff0c;redis使用字典排序&#xff0c;不过不是太明白什…

计算机网络相关知识点(二)

TCP如何保证传输过程的可靠性&#xff1f; 校验和&#xff1a;发送方在发送数据之前计算校验和&#xff0c;接收方收到数据之后同样需要计算&#xff0c;如果不一致&#xff0c;那么代表传输有问题。 确认应答序&#xff0c;序列号&#xff1a;TCP进行传输时数据都进行了编号…

排查内存泄露

1 通过Performance确认是否存在内存泄露 一个存在内存泄露的 DEMO 代码&#xff1a; App.vue <template><div><button click"myFn" style"width: 200px; height: 200px;"></button><home v-if"ishow"></hom…

Xcode15下载iOS17一直中断解决办法

1、问题描述 目前的 xcode 15 安装时&#xff0c;跟以前有个差别&#xff1a;以往的 xcode 安装时自带了 ide、sdk 等工具包&#xff0c;安装后即可开始开发&#xff0c;而最新的包则被分开成了不同的包&#xff0c;这里以 ios 开发包为例&#xff1a;Xcode_15.xip 和 iOS_17_…

软件测试之敏捷项目风险管理

敏捷项目管理是近年来最为流行的项目管理方式之一。这主要归功于敏捷管理的特点&#xff1a;尽早交付、持续改进、灵活管理、团队投入、充分测试。它能充分利用测试周期&#xff0c;并监测每个测试过程中容易出现的问题&#xff0c;加快项目迭代速度&#xff0c;从而推进项目高…

docker 配置 gpu版pytorch环境--部署缺陷检测--Anomalib

目录 一、docker 配置 gpu版pyhorch环境1、显卡驱动、cuda版本、pytorch cuda版本三者对应2、拉取镜像 二、部署Anomalib1、下载Anomalib2、创建容器并且运行3、安装Anomalib进入项目路径安装依赖测试&#xff1a; 一、docker 配置 gpu版pyhorch环境 1、显卡驱动、cuda版本、p…

《优化接口设计的思路》系列:第四篇—接口的权限控制

系列文章导航 《优化接口设计的思路》系列&#xff1a;第一篇—接口参数的一些弯弯绕绕 《优化接口设计的思路》系列&#xff1a;第二篇—接口用户上下文的设计与实现 《优化接口设计的思路》系列&#xff1a;第三篇—留下用户调用接口的痕迹 《优化接口设计的思路》系列&#…

ROS2 从头开始:第 08/8回 - 使用 ROS2 生命周期节点简化机器人软件组件管理

一、说明 欢迎来到我在 ROS2 上的系列的第八部分。对于那些可能不熟悉该系列的人,我已经涵盖了一系列主题,包括 ROS2 简介、如何创建发布者和订阅者、自定义消息和服务创建、

02 MIT线性代数-矩阵消元 Elimination with matrices

一, 消元法 Method of Elimination 消元法是计算机软件求解线形方程组所用的最常见的方法。任何情况下&#xff0c;只要是矩阵A可逆&#xff0c;均可以通过消元法求得Axb的解 eg: 我们将矩阵左上角的1称之为“主元一”&#xff08;the first pivot&#xff09;&#xff0c;第…