第J1周:ResNet-50算法实战与解析

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、前期工作
      • 1、ResNet-50总体结构
      • 2、设置GPU
      • 3、导入数据
    • 二、数据预处理
      • 1、加载数据
      • 2、可视化数据
      • 3、再次检查数据
      • 4、配置数据集
    • 三、构建ResNet-50模型
    • 四、编译
    • 五、训练模型
    • 六、模型评估
    • 七、预测
    • 八、总结

电脑环境:
语言环境:Python 3.8.0
编译器:Jupyter Notebook
深度学习环境:tensorflow 2.17.0

一、前期工作

1、ResNet-50总体结构

在这里插入图片描述

2、设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices('GPU')if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)tf.config.set_visible_devices(gpus[0], 'GPU')

3、导入数据

import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = Falseimport os, PIL, pathlib
import numpy as npfrom tensorflow import keras
from keras import layers, modelsdata_dir = './bird_photos'
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
image_count

565

二、数据预处理

1、加载数据

batch_size = 8
img_height = 224
img_width = 224train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset='training',seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset='validation',seed=123,image_size=(img_height, img_width),batch_size=batch_size)

我们可以通过class_names输出数据集的标签,按字母顺序对应于目录名称。

class_names = train_ds.class_names
class_names

[‘Bananaquit’, ‘Black Skimmer’, ‘Black Throated Bushtiti’, ‘Cockatoo’]

2、可视化数据

plt.figure(figsize=(10, 4))for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i+1)plt.imshow(images[i].numpy().astype('uint8'))plt.title(class_names[labels[i]])plt.axis('off')

在这里插入图片描述

plt.imshow(images[0].numpy().astype('uint8'))

在这里插入图片描述

3、再次检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

(8, 224, 224, 3)
(8,)

4、配置数据集

AUTOTUNE = tf.data.experimental.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建ResNet-50模型

from keras import layersfrom keras.layers import Input, Activation, BatchNormalization, Flatten
from keras.layers import Dense, Conv2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras.models import Modeldef identity_block(input_tensor, kernel_size, filters, stage, block):filters1, filters2, filters3 = filtersname_base = str(stage) + block + '_identity_block_'x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)x = BatchNormalization(name=name_base + 'bn1')(x)x = Activation('relu', name=name_base + 'relu1')(x)x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)x = BatchNormalization(name=name_base + 'bn2')(x)x = Activation('relu', name=name_base + 'relu2')(x)x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)x = BatchNormalization(name=name_base + 'bn3')(x)x = layers.add([x, input_tensor], name=name_base + 'add')x = Activation('relu', name=name_base + 'relu3')(x)return xdef conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):filters1, filters2, filters3 = filtersres_name_base = str(stage) + block + '_conv_block_res_'name_base = str(stage) + block + '_conv_block_'x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)x = BatchNormalization(name=name_base + 'bn1')(x)x = Activation('relu', name=name_base + 'relu1')(x)x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)x = BatchNormalization(name=name_base + 'bn2')(x)x = Activation('relu', name=name_base + 'relu2')(x)x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)x = BatchNormalization(name=name_base + 'bn3')(x)shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)x = layers.add([x, shortcut], name=name_base + 'add')x = Activation('relu', name=name_base + 'relu3')(x)return xdef ResNet50(input_shape=(224, 224, 3), classes=1000):img_input = Input(shape=input_shape)x = ZeroPadding2D((3, 3))(img_input)x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)x = BatchNormalization(name='bn_conv1')(x)x = Activation('relu')(x)x = MaxPooling2D((3, 3), strides=(2, 2))(x)x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')x = AveragePooling2D((7, 7), name='avg_pool')(x)x = Flatten()(x)x = Dense(classes, activation='softmax', name='fc1000')(x)model = Model(img_input, x, name='resnet50')# 加载预训练模型model.load_weights('./resnet50_weights_tf_dim_ordering_tf_kernels.h5')return modelmodel = ResNet50()
model.summary()
Model: "resnet50"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)              ┃ Output Shape           ┃        Param # ┃ Connected to           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)(None, 224, 224, 3)0-                      │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ zero_padding2d            │ (None, 230, 230, 3)0 │ input_layer[0][0]      │
│ (ZeroPadding2D)           │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv1 (Conv2D)(None, 112, 112, 64)9,472 │ zero_padding2d[0][0]   │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ bn_conv1                  │ (None, 112, 112, 64)256 │ conv1[0][0]            │
│ (BatchNormalization)      │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ activation (Activation)(None, 112, 112, 64)0 │ bn_conv1[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ max_pooling2d             │ (None, 55, 55, 64)0 │ activation[0][0]       │
│ (MaxPooling2D)            │                        │                │                        │
..............................................................
..............................................................
..............................................................
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ avg_pool                  │ (None, 1, 1, 2048)0 │ 5c_identity_block_rel… │
│ (AveragePooling2D)        │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ flatten (Flatten)(None, 2048)0 │ avg_pool[0][0]         │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ fc1000 (Dense)(None, 1000)2,049,000 │ flatten[0][0]          │
└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘Total params: 25,636,712 (97.80 MB)Trainable params: 25,583,592 (97.59 MB)Non-trainable params: 53,120 (207.50 KB)

四、编译

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=opt,loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、训练模型

epochs = 10
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 269s 1s/step - accuracy: 0.5021 - loss: 3.6748 - val_accuracy: 0.9646 - val_loss: 0.1640
Epoch 2/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 6s 99ms/step - accuracy: 0.9636 - loss: 0.2068 - val_accuracy: 0.9823 - val_loss: 0.0241
Epoch 3/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 6s 97ms/step - accuracy: 0.9800 - loss: 0.0443 - val_accuracy: 0.9912 - val_loss: 0.0115
Epoch 4/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 10s 98ms/step - accuracy: 0.9943 - loss: 0.0286 - val_accuracy: 0.9912 - val_loss: 0.0183
Epoch 5/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 11s 104ms/step - accuracy: 0.9945 - loss: 0.0377 - val_accuracy: 1.0000 - val_loss: 0.0108
Epoch 6/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 10s 100ms/step - accuracy: 0.9995 - loss: 0.0038 - val_accuracy: 0.9735 - val_loss: 0.0359
Epoch 7/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 10s 104ms/step - accuracy: 1.0000 - loss: 0.0024 - val_accuracy: 0.9912 - val_loss: 0.0196
Epoch 8/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 10s 98ms/step - accuracy: 1.0000 - loss: 6.2409e-04 - val_accuracy: 0.9912 - val_loss: 0.0139
Epoch 9/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 6s 106ms/step - accuracy: 1.0000 - loss: 5.9430e-04 - val_accuracy: 1.0000 - val_loss: 0.0103
Epoch 10/10
57/57 ━━━━━━━━━━━━━━━━━━━━ 6s 99ms/step - accuracy: 1.0000 - loss: 3.5871e-04 - val_accuracy: 1.0000 - val_loss: 0.0094

六、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、预测

plt.figure(figsize=(10, 4))for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i+1)plt.imshow(images[i].numpy().astype('uint8'))img_array = tf.expand_dims(images[i], 0)predictions = model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis('off')

在这里插入图片描述

八、总结

在一般的卷积神经网络中,由于深度的增加,可能会带来梯度爆炸,梯度消失,ResNet的残差网络结构可以有效解决这些问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1534148.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

初级练习[2]:Hive SQL查询汇总分析

目录 SQL查询汇总分析 成绩查询 查询编号为“02”的课程的总成绩 查询参加考试的学生个数 分组查询 查询各科成绩最高和最低的分 查询每门课程有多少学生参加了考试(有考试成绩) 查询男生、女生人数 分组结果的条件 查询平均成绩大于60分的学生的学号和平均成绩 查询至少…

基于python+django+vue的农业管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于pythondjangovueMySQL的农…

C++ push_back和emplace_back的区别

基本类型情况西&#xff0c;两者几乎没什么区别 但是再自定义类型的时候&#xff1f;emplace——back更高效&#xff0c;但是emplace_back 没有类型检查的安全&#xff1b;只有运行时候才会报错。 #include <vector> #include <iostream> using namespace std; …

基于 CycleGAN 对抗网络的自定义数据集训练

目录 生成对抗网络&#xff08;GAN&#xff09; CycleGAN模型训练 训练数据生成 下载开源项目CycleGAN 配置训练环境 开始训练 模型测试 可视化结果 生成对抗网络&#xff08;GAN&#xff09; 首先介绍一下什么是GAN网络&#xff0c;它是由生成器&#xff08;Generator…

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM 文章目录 一、基本原理DE-SVM 分类预测原理和流程总结 二、实验结果三、核心代码四、代码获取五、总结 一、基本原理 DE-SVM 分类预测原理和流程 1. 差分进化优化算法&#xff08;DE&#xff09; 原理…

【运维监控】Prometheus+grafana监控tomcat运行情况

运维监控系列文章入口&#xff1a;【运维监控】系列文章汇总索引 文章目录 一、prometheus二、grafana三、tomcat与jmx_exporter配置1、下载jmx_exporter2、部署jmx_exporter3、添加tomcat的配置信息4、修改tomcat的启动文件5、重启tomcat及验证6、其他 四、集成prometheus与gr…

vue3 动态 svg 图标使用

前言 在做后台管理系统中,我们经常会用到很多图标,比如左侧菜单栏的图标 当然这里 element-ui 或者 element-plus 组件库都会提供图标 但是在有些情况下 element-ui 或者 element-plus 组件库提供的图标满足不了我们的需求时,这个时候我们就需要自己去网上找一些素材或者…

CAN通讯常见错误

CAN通讯常见错误 1.在使用CAN设备进行数据通讯时&#xff0c;有时候参数配置不当可能就会导致通讯的失败&#xff0c;如下图1所示&#xff0c;出现通信错误的原因是两个设备的波特率配置不一致导致。 图1 2.有时候在配置参数的时候&#xff0c;不能只关注波特率速度配置一致…

JEE 设计模式

Java 数据访问对象模式 Java设计模式 - 数据访问对象模式 数据访问对象模式或DAO模式将数据访问API与高级业务服务分离。 DAO模式通常具有以下接口和类。 数据访问对象接口定义模型对象的标准操作。 数据访问对象类实现以上接口。可能有多个实现&#xff0c;例如&#xff0c…

关于Redis缓存一致性问题的优化和实践

目录标题 导语正文分布式场景下无法做到强一致即使是达到最终一致性也很难缓存的一致性问题缓存是如何写入的 如何感知数据库的变化最佳实践一&#xff1a;数据库变更后失效缓存最佳实践二&#xff1a;带版本写入 总结与展望阿里XKV腾讯DCache 导语 Redis缓存一致性的问题是经…

【API安全】威胁猎人发布超大流量解决方案

随着数字化进程加速&#xff0c;企业API接口数量激增&#xff0c;已经成为连接内外部服务的重要桥梁。然而&#xff0c;对于拥有庞大的外部客户群体和错综复杂的内部业务系统的大型企业而言&#xff0c;API安全管控面临超大流量下的性能瓶颈与数据安全双重挑战。 性能上&#…

【软件测试】常用的开发、测试模型

哈喽&#xff0c;哈喽&#xff0c;大家好~ 我是你们的老朋友&#xff1a;保护小周ღ 今天给大家带来的是 【软件测试】常用的开发、测试模型&#xff0c;首先了解, 什么是软件的生命周期, 测试的生命周期, 常见的开发模型: 瀑布, 螺旋, 增量, 迭代, 敏捷. 常用的测试模型, …

Serverless 安全新杀器:云安全中心护航容器安全

作者&#xff1a;胡志广(独鳌) 云安全中心对于 Serverless 容器用户的价值 从云计算发展之初&#xff0c;各大云厂商及传统安全厂商就开始围绕云计算的形态来做安全解决方案。传统安全与云计算安全的形态与做法开始发生变化&#xff0c;同时随着这 10 多年的发展&#xff0c;…

ThreeJS入门(002):学习思维路径

查看本专栏目录 - 本文是第 002篇入门文章 文章目录 如何使用这个思维导图 Three.js 学习思维导图可以帮助你系统地了解 Three.js 的各个组成部分及其关系。下面是一个简化的 Three.js 学习路径思维导图概述&#xff0c;它包含了学习 Three.js 的主要概念和组件。你可以根据这个…

Redis 入门 - 收官

《Redis 入门》系列文章总算完成了&#xff0c;希望这个系列文章可以想入门或刚入门的同学提供帮助&#xff0c;希望能让你形成学习Redis系统性概念。 当时为什么要写这个系列文章&#xff0c;是因为我自己就是迷迷糊糊一路踩坑走过来的&#xff0c;我踩完的坑就踩完了&#x…

Kamailio-基于Zabbix+Kamcli的SIP指标监控

什么是Kamailio? Kamailio 是一个开源的 Session Initiation Protocol (SIP) 服务器&#xff0c;它主要用于建立和管理实时通信会话&#xff0c;如语音和视频通话&#xff0c;与opensips这个产品是同根同源的存在。它们相似&#xff0c;没有更好&#xff0c;是有更合适。 此…

LLM - 理解 多模态大语言模型 (MLLM) 的指令微调与相关技术 (四)

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/142063880 免责声明&#xff1a;本文来源于个人知识与公开资料&#xff0c;仅用于学术交流&#xff0c;欢迎讨论&#xff0c;不支持转载。 完备(F…

获取京东商品详情数据API接口优惠券信息(通过商品id获取商品详情页数据)调用说明文档

在当今数字化时代&#xff0c;应用程序之间的互操作性已成为推动业务创新和技术进步的关键因素。API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&#xff09;作为这一生态系统中不可或缺的一环&#xff0c;扮演着连接不同软件服务、数据资源和…

AE 让合成重复循环播放

在合成上点右键 > Time > Enable Time Remapping 按住 Alt 键&#xff0c;点秒表图标 输入 loop_out("cycle", 0) 将子合成拖到此合成结束的位置 结束

Ton的编译过程(上)

系列文章目录 FunC编写初始准备 文章目录 系列文章目录预先准备第一个FunC合约深入compileFunc的内部compileFunc初探艾丽卡的疑惑package.json 初览index.js 预先准备 首先请大家跟着艾丽卡一步一步的完成FunC编写初始准备 这里面环境的搭建。 接下来&#xff0c;请做好下面…