图像分类架构

图像分类

  • 一、图像分类简介
  • 二、AlexNet
  • 三、VGG网络架构
  • 四、GoogLeNet
    • 4.1 Inception模块
    • 4.2 GoogLeNet构建
  • 五、ResNet
    • 5.1 定义ResNet的残差块
    • 5.2 ResNet网络中模块的构成
    • 5.3 ResNet网络的构建
  • 六、图像增强
  • 七、模型微调

一、图像分类简介

图像分类实质上就是从给定的类别集合中为图像分配对应标签的任务。也就是说我们的任务是分析一个输入输入图像并返回一个该图像类别的标签。
图像分类常用的数据集:mnist、CIFAR-100、CIFAT-10、ImageNet

二、AlexNet

  • AlexNet包含8层变换,有5层卷积和2层全连接隐藏层,以及一个全连接输出层。
  • AlexNet通过DropOut来控制全连接层的模型复杂度。
  1. AlexNet模型
# 导包
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import mnist
# 构建AlexNet模型
model = Sequential([layers.Conv2D(filters=96,kernel_size=11,strides=4,activation="relu"),layers.MaxPool2D(pool_size=3,strides=2),layers.Conv2D(filters=256,kernel_size=5,strides=1,padding="same",activation="relu"),layers.MaxPool2D(pool_size=3,strides=2),layers.Conv2D(filters=384,kernel_size=3,strides=1,padding="same",activation="relu"),layers.Conv2D(filters=384,kernel_size=3,strides=1,padding="same",activation="relu"),layers.Conv2D(filters=256,kernel_size=3,strides=1,padding="same",activation="relu"),layers.MaxPool2D(pool_size=3,strides=2),layers.Flatten(),layers.Dense(4096,activation='relu'),layers.Dropout(0.5),layers.Dense(4096,activation='relu'),layers.Dropout(0.5),layers.Dense(10,activation='softmax')
])
x = tf.random.uniform((1,227,227,1))
y = model(x)
model.summary()

在这里插入图片描述

  1. 手写数字势识别
# 加载数据
(train_image,train_label),(test_image,test_label) = mnist.load_data()
train_image.shape,test_image.shape

在这里插入图片描述

# 维度调整
train_image = train_image.reshape(60000,28,28,1)
test_image = test_image.reshape(10000,28,28,1)
import numpy as np
# 随机抽取样本
def get_trian(size):index = np.random.choice(60000,size,replace=False)# 将样本resize为227*227大小resize_image = tf.image.resize_with_pad(train_image[index],227,227)return resize_image.numpy(),train_label[index]
def get_test(size):index = np.random.choice(10000,size,replace=False)resize_image = tf.image.resize_with_pad(test_image[index],227,227)return resize_image.numpy(),test_label[index]
# 获取训练样本和测试样本
train_image, train_label = get_trian(256)
test_image, test_label = get_test(128)
# 数据展示
import matplotlib.pyplot as plt
for i in range(9):plt.subplot(3,3,i+1)plt.imshow(train_image[i],cmap='gray')plt.title(train_label[i])

在这里插入图片描述

# 模型编译
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(optimizer=optimizer,loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=['accuracy'])
# 模型训练
model.fit(train_image,train_label,batch_size=128,epochs=3,validation_split=0.1,verbose=1)

在这里插入图片描述

# 模型评估
model.evaluate(test_image,test_label,verbose=1)

在这里插入图片描述

三、VGG网络架构

VGG可以看成加深版的AlexNet,整个网络由卷积层和全连接层叠加而成。VGGNet使用的全部都是33的小卷积核和22的池化核,通过不断加深网络来提升性能。VGG可以通过重复使用简单的基础块来构建深度模型。

# 模型构建
def vgg(convs_arch):model = Sequential()# VGG块的构建for (num_convs,num_filters) in convs_arch:for _ in range(num_convs):model.add(layers.Conv2D(num_filters,kernel_size=3,padding='same',activation='relu'))model.add(layers.MaxPool2D(pool_size=2,strides=2))# 卷积块后添加全连接层model.add(Sequential([layers.Flatten(),layers.Dense(4096,activation='relu'),layers.Dropout(0.5),layers.Dense(4096,activation='relu'),layers.Dropout(0.5),layers.Dense(10,activation='softmax'),]))return model
# 卷积块参数
convs_arch = ((2,64),(2,128),(3,256),(3,512),(3,512))
model = vgg(convs_arch)
x = tf.random.uniform((1,224,224,1))
y = model(x)
model.summary()

在这里插入图片描述

四、GoogLeNet

Inception 层通过多个路径(不同的卷积核大小和池化操作)提取不同尺度的特征,并将这些特征图拼接在一起。每次通过一个 Inception 层时,可能会增加特征图的数量。

4.1 Inception模块

Inception 模块的主要目的是通过并行使用不同大小的卷积核来捕捉不同尺度的空间信息,并通过拼接这些特征图来增强模型的表征能力。

class Inception(tf.keras.layers.Layer):def __init__(self,c1,c2,c3,c4):super().__init__()# 线路1self.p1_1 = layers.Conv2D(c1,kernel_size=1,activation='relu',padding='same')# 线路2self.p2_1 = layers.Conv2D(c2[0],kernel_size=1,activation='relu',padding='same')self.p2_2 = layers.Conv2D(c2[1],kernel_size=3,activation='relu',padding='same')# 线路3self.p3_1 = layers.Conv2D(c3[0],kernel_size=1,activation='relu',padding='same')self.p3_2 = layers.Conv2D(c3[1],kernel_size=5,activation='relu',padding='same')# 线路4self.p4_1 = layers.MaxPool2D(pool_size=3,padding='same',strides=1)self.p4_2 = layers.Conv2D(c4,kernel_size=1,activation='relu',padding='same')# 向前传播过程def call(self,input):# 线路1p1 = self.p1_1(input)# 线路2p2 = self.p2_2(self.p2_1(input))# 线路3p3 = self.p3_2(self.p3_1(input))# 线路4p4 = self.p4_2(self.p4_1(input))outputs = tf.concat([p1,p2,p3,p4],axis=-1)return outputs
# 指定通道数,对Inception进行实例化
Inception(64,(96,128),(16,32),32)

4.2 GoogLeNet构建

# B1模块
inputs = tf.keras.Input(shape=(224,224,3),name="input")
x = tf.keras.layers.Conv2D(64,kernel_size=7,strides=2,padding="same",activation="relu")(inputs)
x = tf.keras.layers.MaxPool2D(pool_size=3,strides=2,padding="same")(x)
# B2模块
x = tf.keras.layers.Conv2D(64,kernel_size=1,strides=2,padding="same",activation="relu")(x)
x = tf.keras.layers.Conv2D(192,kernel_size=3,strides=2,padding="same",activation="relu")(x)
x = tf.keras.layers.MaxPool2D(pool_size=3,strides=2,padding="same")(x)
# B2模块
x = tf.keras.layers.Conv2D(64,kernel_size=1,strides=2,padding="same",activation="relu")(x)
x = tf.keras.layers.Conv2D(192,kernel_size=3,strides=2,padding="same",activation="relu")(x)
x = tf.keras.layers.MaxPool2D(pool_size=3,strides=2,padding="same")(x)
# B4模块
# 辅助分类器
def aux_classifier(x,filter_size):x = tf.keras.layers.AveragePooling2D(pool_size=5,strides=3,padding="same")(x)x = tf.keras.layers.Conv2D(filters=filter_size[0],kernel_size=1,strides=1,padding='valid',activation="relu")(x)x = tf.keras.layers.Flatten()(x)x = tf.keras.layers.Dense(units=filter_size[1],activation="relu")(x)x = tf.keras.layers.Dense(10,activation="softmax")(x)return x
x = Inception(192,(96,208),(16,48),64)(x)
# 辅助输出
aux_output1 = aux_classifier(x,[128,1024])
# Inception层
x = Inception(160, (112, 224), (24, 64), 64)(x)
x = Inception(128, (128, 256), (24, 64), 64)(x)
x = Inception(112, (144, 288), (32, 64), 64)(x)
# 辅助输出
aux_output2 = aux_classifier(x,[128,1024])
x = Inception(256,(160,320),(32,128),128)(x)
x = tf.keras.layers.MaxPool2D(pool_size=3,strides=2,padding='same')(x)
# B5模块
x = Inception(256,(160,320),(32,128),128)(x)
x = Inception(384,(192,384),(48,128),128)(x)
# 全局平均池化层(GPA),用来代替全连接层
x = tf.keras.layers.GlobalAvgPool2D()(x)
outputs = tf.keras.layers.Dense(10,activation='softmax')(x)
# 模型
model = tf.keras.Model(inputs=inputs,outputs=[outputs,aux_output1,aux_output2])
model.summary()

在这里插入图片描述

五、ResNet

5.1 定义ResNet的残差块

残差块主要用于解决深层神经网络中的梯度消失问题和训练过程中的退化问题。

class Residual(tf.keras.Model):def __init__(self,num_filters,use_1X1conv=False,strides=1):super(Residual,self).__init__()self.conv1 = tf.keras.layers.Conv2D(num_filters,kernel_size=3,strides=strides,padding='same')self.conv2 = tf.keras.layers.Conv2D(num_filters,kernel_size=3,strides=1,padding='same')if use_1X1conv:self.conv3 = tf.keras.layers.Conv2D(num_filters,kernel_size=1,strides=strides,padding='same')else:self.conv3 = None# BN层self.bn1 = tf.keras.layers.BatchNormalization()self.bn2 = tf.keras.layers.BatchNormalization()# 向前传播过程def call(self,x):y = tf.keras.activations.relu(self.bn1(self.conv1(x)))y = self.bn2(self.conv2(y))if self.conv3:x = self.conv3(x)# "跳跃连接"返回相加后激活的结果outputs = tf.keras.activations.relu(y + x)return outputs  

5.2 ResNet网络中模块的构成

ResnetBlock旨在通过串联多个残差块来构建深层神经网络,从而改善模型的训练稳定性和泛化能力。

# ResNet网络中模块的构成
class ResnetBlock(tf.keras.layers.Layer):# 定义所需的网络结构def __init__(self,num_filters,num_res,first_block=False):super(ResnetBlock,self).__init__()# 存储残差块self.listLayers = []for i in range(num_res):# 若为第一个残差块并且不是第一个模块,使用1*1卷积if i == 0 and not first_block:self.listLayers.append(Residual(num_filters,use_1X1conv=True,strides=2))else:self.listLayers.append(Residual(num_filters))def call(self,x):for layer in self.listLayers:x = layer(x)return x

5.3 ResNet网络的构建

# 构建ResNet网络
class ResNet(tf.keras.Model):# 定义网络的组成def __init__(self,num_blocks):super(ResNet,self).__init__()self.conv = tf.keras.layers.Conv2D(64,kernel_size=7,strides=2,padding='same')self.bn = tf.keras.layers.BatchNormalization()self.relu = tf.keras.layers.Activation('relu')self.mp = tf.keras.layers.MaxPool2D(pool_size=3,strides=2,padding='same')# 残差模块self.res_block1 = ResnetBlock(64,num_blocks[0],first_block=True)self.res_block2 = ResnetBlock(128,num_blocks[1])self.res_block3 = ResnetBlock(256,num_blocks[2])self.res_block4 = ResnetBlock(512,num_blocks[3])# GAPself.gap = tf.keras.layers.GlobalAvgPool2D()# 全连接层self.fc = tf.keras.layers.Dense(units=10,activation='softmax')# 定义前向传播过程def call(self,x):x = self.conv(x)x = self.bn(x)x = self.relu(x)x = self.mp(x)x = self.res_block1(x)x = self.res_block2(x)x = self.res_block3(x)x = self.res_block4(x)x = self.gap(x)x = self.fc(x)return x
# 实例化
my_net = ResNet([2,2,2,2])
x = tf.random.uniform((1,224,224,1))
y = my_net(x)
my_net.summary()

在这里插入图片描述

六、图像增强

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
# 读取图像并显示
cat = plt.imread("./cat.jpg")
plt.imshow(cat)

在这里插入图片描述

# 左右翻转
cat = tf.image.random_flip_left_right(cat)
plt.imshow(cat)

在这里插入图片描述

# 上下翻转
image = tf.image.random_flip_up_down(cat)
plt.imshow(image)

在这里插入图片描述

# 随机裁剪
image_1 = tf.image.random_crop(cat,(200,200,3))
plt.imshow(image_1)

在这里插入图片描述

# 亮度调整
image_2 = tf.image.random_brightness(cat,0.5)
plt.imshow(image_2)

在这里插入图片描述

# 颜色调整
image_3 = tf.image.random_hue(cat,0.4)
plt.imshow(image_3)

在这里插入图片描述

# 使用ImageDataGenerator()进行图像增强
tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=0, #随机旋转的度数范围width_shift_range=0.0, #宽度平移height_shift_range=0.0, #高度平移brightness_range=None, #亮度调整shear_range=0.0, #裁剪zoom_range=0.0, #缩放horizontal_flip=False, #左右翻转vertical_flip=False, # 垂直翻转rescale=None, # 尺度调整
)
from tensorflow.keras.datasets import mnist
# 数据获取
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
x_train.shape,x_test.shape

在这里插入图片描述

x_train = x_train.reshape(60000,28,28,1)
x_test = x_test.reshape(10000,28,28,1)
# 实例化
datagen = tf.keras.preprocessing.image.ImageDataGenerator(horizontal_flip=True)
for x,y in datagen.flow(x_train,y_train,batch_size=9):for i in range(0,9):plt.subplot(3,3,i + 1)plt.imshow(x[i])plt.title(y[i])plt.show()break

在这里插入图片描述

七、模型微调

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1534315.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Spring扩展点系列-BeanFactoryAware

文章目录 简介源码分析示例代码示例一:验证BeanFactoryAware执行顺序示例二:动态获取其他bean示例三:动态bean的状态 简介 spring容器中Bean的生命周期内所有可扩展的点的调用顺序 扩展接口 实现接口ApplicationContextlnitializer initia…

git 更换远程地址的方法

需要将正在开发的代码远程地址改成新的地址,通过查询发现有三个方法可以实现,特此记录。具体方法如下: (1)通过命令直接修改远程仓库地址 git remote 查看所有远程仓库git remote xxx 查看指定远程仓库地址git remote…

MySQL代码顺序(整合)

这个图片也就是说明执行顺序 FROM > WHERE > GOURP BY > HAVING > SELECT > ORDER BY > LIMIT; 编写按照这个顺序写即可。

SX_VMware联网_23

利用Nat模式联网,NAT模式(Network Address Translation): 在NAT模式下,虚拟机通过主机的网络接口访问外部网络。 虚拟机之间可以相互通信,也可以访问主机网络以及互联网。 虚拟机使用私有IP地址&#xff0c…

工业互联网网络集成与实训系统解决方案

随着工业4.0时代的到来和信息技术的高速发展,工业互联网已成为推动产业升级的重要力量。本方案旨在通过构建高度仿真的实训环境,帮助学生全面掌握工业互联网技术,为未来的职业生涯奠定坚实基础。 一、设计理念 在设计理念上,本方…

【GIS开发小课堂】写一个高德地图巡航功能的小DEMO

介绍 此项目使用vite为基础架构,内部实现均以typescript开发,可替换为自己的业务逻辑,并迁移到react,vue,umi等其他框架。 通过调用高德地图的API和threejs的开发,实现了一个小鸭子(可替换为自己…

TiDB 扩容过程中 PD 生成调度的原理及常见问题丨TiDB 扩缩容指南(一)

导读 作为一个分布式数据库,扩缩容是 TiDB 集群最常见的运维操作之一。本系列文章,我们将基于 v7.5.0 具体介绍扩缩容操作的具体原理、相关配置及常见问题的排查。 通常,我们根据当前资源状态来决定是否需要调整 TiKV 节点的规模&#xff0…

Version ‘18.19.0‘ not found - try `nvm ls-remote` to browse available versions.

nvm安装指定版本不好使了 使用 nvm install 18.19.0 一直报错 Version 18.19.0 not found - try nvm ls-remote to browse available versions.然而使用 nvm ls-remote 只看到 iojs-v1.0.0iojs-v1.0.1iojs-v1.0.2iojs-v1.0.3iojs-v1.0.4iojs-v1.1.0iojs-v1.2.0iojs-v1.3.0iojs…

Wildberries测评自养号支付下单技术

Wildberries(俄语:ООО Ягодки)是俄罗斯最大的在线零售商,由Tatyana Bakalchuk于 2004 年创立。除俄罗斯外,他们还在其他 15 个国家提供服务:亚美尼亚、白俄罗斯、法国、德国、以色列、意大利、哈萨…

PHP省时省力海报在线制作系统小程序源码

省时省力海报在线制作系统:设计小白也能秒变大师 🎨 开篇:告别繁琐,拥抱高效设计 你还在为设计一张海报而熬夜加班吗?还在为找不到合适的素材而焦头烂额吗?别担心,“省时省力海报在线制作系统”…

使用开源框架HandyControl

准备 NuGet 搜索安装 HandyControl。 在App.xaml中添加以下代码&#xff1a; <Application.Resources><ResourceDictionary><ResourceDictionary.MergedDictionaries><ResourceDictionary Source"pack://application:,,,/HandyControl;component/…

大雪纷飞的视频素材去哪里找啊?雪景素材库分享

当冬季的银装素裹覆盖大地&#xff0c;无数抖音创作者便开始寻找那些可以捕捉到大雪纷飞的壮观画面。无论是为了制作节日主题的视频、记录下雪天的活动&#xff0c;还是单纯展示雪的清新美&#xff0c;优质的大雪视频素材都显得尤为重要。如果你正为寻找这类素材而苦恼&#xf…

建造者模式:灵活构建复杂对象的利器

在软件开发中&#xff0c;创建一个复杂对象通常需要多个步骤和参数&#xff0c;直接在客户端代码中进行这些操作不仅繁琐&#xff0c;而且难以维护。建造者模式&#xff08;Builder Pattern&#xff09;提供了一种优雅的解决方案&#xff0c;使得对象的创建过程更加清晰、灵活和…

磁盘写操作压力测试工具的设计与实现

磁盘写操作压力测试工具的设计与实现 1. 设计概述2. 关键技术点3. 伪代码设计4. C代码实现5. 运行与测试6. 结论在进行磁盘性能评估时,写操作压力测试是不可或缺的一部分。本篇文章将介绍如何使用C语言结合系统调用,设计并实现一个针对磁盘写操作的压力测试工具。这个工具将模…

LINUX网络编程:http

目录 1.认识http请求的字段 2.HTTP请求类 3.认识HTTP应答字段 4.HTTP应答类 5.源代码 协议就是一种约定&#xff0c;http也并不例外&#xff0c;使用http也无非就是&#xff0c;定义一个http请求的结构体&#xff0c;将结构体序列化为字符串&#xff0c;发送给服务器&…

2024年06月中国电子学会青少年软件编程(图形化)等级考试试卷(一级)答案 + 解析

青少年软件编程&#xff08;图形化&#xff09;等级考试试卷&#xff08;一级&#xff09; 分数&#xff1a;100 题数&#xff1a;37 一、单选题 音乐Video Game1的时长将近8秒&#xff0c;点击一次角色&#xff0c;下列哪个程序不能完整地播放音乐两次&#xff1f;&#xff0…

【Hot100】LeetCode—169. 多数元素

目录 1- 思路题目识别技巧 2- 实现⭐136. 只出现一次的数字——题解思路 3- ACM 实现 原题链接&#xff1a;169. 多数元素 1- 思路 题目识别 识别1 &#xff1a;统计数组中出现数量多余 [n/2] 的元素 技巧 值相同&#xff0c;则对 count 1&#xff0c;如果不相同则对值进行…

【C#】VS插件

翻译 目前推荐较多的 可以单词发言&#xff0c;目前还在开发阶段 TranslateIntoChinese - Visual Studio Marketplace 下载量最高的(推荐) Visual-Studio-Translator - Visual Studio Marketplace 支持翻译的版本较多&#xff0c;在 Visual Studio 代码编辑器中通过 Googl…

vue使用TreeSelect设置带所有父级节点的回显

Element Plus的el-tree-select组件 思路&#xff1a; 选中节点时&#xff0c;给选中的节点赋值 pathLabel&#xff0c;pathLabel 为函数生成的节点名字拼接&#xff0c;数据源中不包含。 在el-tree-select组件中设置 props“{ label: ‘pathLabel’ }” 控制选中时input框中回…

【信创】推荐一款好用的免费在线流程图思维导图工具 _ 统信 _ 麒麟 _ 方德

原文链接&#xff1a;【信创】推荐一款好用的免费在线流程图思维导图工具 | 统信 | 麒麟 | 方德 Hello&#xff0c;大家好啊&#xff01;今天给大家推荐一款非常好用的免费在线流程图和思维导图工具——ProcessOn。无论是项目管理、数据分析、头脑风暴还是日常办公&#xff0c;…