yolov8实战第四天——yolov8图像分类 ResNet50图像分类(保姆式教程)

yolov8实战第一天——yolov8部署并训练自己的数据集(保姆式教程)_yolov8训练自己的数据集-CSDN博客在前几天,我们使用yolov8进行了部署,并在目标检测方向上进行自己数据集的训练与测试,今天我们训练下yolov8的图像分类,看看效果如何,同时使用resnet50也训练一个分类模型,看看哪个效果好!

图像分类是指将输入的图像自动分类为不同的类别。它是计算机视觉领域的一个重要应用,可以用于人脸识别、物体识别、场景分类等任务。

通常情况下,图像分类的流程如下:

  1. 收集和准备数据集:收集与任务相关的图像数据,并将其打上标签。
  2. 定义模型:选择一种适合于你的任务的深度学习模型,例如卷积神经网络(CNN)。
  3. 训练模型:使用收集到的数据集对模型进行训练,通过反向传播算法来更新模型参数,使其可以根据输入图像进行正确的分类。
  4. 评估模型性能:使用测试集对已经训练好的模型进行评估,比较模型预测结果与真实标签之间的差异,从而评估模型的性能。
  5. 使用模型进行预测:使用已经训练好的模型对新的图像进行分类预测。

在实际应用中,可以使用各种深度学习框架(例如 TensorFlow、PyTorch、Keras 等)来构建图像分类模型,并使用各种数据增强技术(例如旋转、缩放、裁剪等)来增加数据集的多样性和数量。

如果你想学习如何使用深度学习框架来构建图像分类模型,可以参考一些在线教程、书籍或者 MOOC。

一、yolov8图像分类

1.模型选型

下载yolov8分类模型。

分别使用模型进行测试:

yolov8n-cls效果:

yolov8m-cls效果:

总结:n效果不咋地,还是得使用m进行后续训练工作。 

2.数据集准备

皮肤癌检测_数据集-飞桨AI Studio星河社区

同目标检测,还是放在datasets下。

直接改成这个,省去分数据集操作。 

 3.训练

yolo classify train data=./datasets/skin-cancer-detection model=yolov8n-cls.pt epochs=100

测试:

yolo classify predict model=runs/classify/train4/weights/best.pt source='./datasets/skin-cancer-detection/train/nevus'

  

label: 

 pred:

总结:数据集比较小,yolov8效果不太好。

、resnet50图像分类

Resnet50 网络中包含了 49 个卷积层、一个全连接层。如图下图所示,Resnet50网络结构可以分成七个部分,第一部分不包含残差块,主要对输入进行卷积、正则化、激活函数、最大池化的计算。第二、三、四、五部分结构都包含了残差块,图 中的绿色图块不会改变残差块的尺寸,只用于改变残差块的维度。在 Resnet50 网 络 结 构 中 , 残 差 块 都 有 三 层 卷 积 , 那 网 络 总 共 有1+3×(3+4+6+3)=49个卷积层,加上最后的全连接层总共是 50 层,这也是Resnet50 名称的由来。网络的输入为 224×224×3,经过前五部分的卷积计算,输出为 7×7×2048,池化层会将其转化成一个特征向量,最后分类器会对这个特征向量进行计算并输出类别概率。

运行train.py即可。

train.py

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timeimport numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm# 一、建立数据集
# animals-6
#   --train
#       |--dog
#       |--cat
#       ...
#   --valid
#       |--dog
#       |--cat
#       ...
#   --test
#       |--dog
#       |--cat
#       ...
# 我的数据集中 train 中每个类别60张图片,valid 中每个类别 10 张图片,test 中每个类别几张到几十张不等,一共 6 个类别。# 二、数据增强
# 建好的数据集在输入网络之前先进行数据增强,包括随机 resize 裁剪到 256 x 256,随机旋转,随机水平翻转,中心裁剪到 224 x 224,转化成 Tensor,正规化等。
image_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),transforms.RandomRotation(degrees=15),transforms.RandomHorizontalFlip(),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
}# 三、加载数据
# torchvision.transforms包DataLoader是 Pytorch 重要的特性,它们使得数据增加和加载数据变得非常简单。
# 使用 DataLoader 加载数据的时候就会将之前定义的数据 transform 就会应用的数据上了。
dataset = 'skin-cancer-detection'
train_directory = './skin-cancer-detection/train'
valid_directory = './skin-cancer-detection/val'batch_size = 32
num_classes = 9 #分类种类数
print(train_directory)
data = {'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])
}
print("训练集图片类别及其对应编号(种类名:编号):",data['train'].class_to_idx)
print("测试集图片类别及其对应编号:",data['valid'].class_to_idx)train_data_size = len(data['train'])
valid_data_size = len(data['valid'])train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=0)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True, num_workers=0)print("训练集图片数量:",train_data_size, "测试集图片数量:",valid_data_size)# 四、迁移学习
# 这里使用ResNet-50的预训练模型。
#resnet50 = models.resnet50(pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 在PyTorch中加载模型时,所有参数的‘requires_grad’字段默认设置为true。这意味着对参数值的每一次更改都将被存储,以便在用于训练的反向传播图中使用。
# 这增加了内存需求。由于预训练的模型中的大多数参数已经训练好了,因此将requires_grad字段重置为false。
for param in resnet50.parameters():param.requires_grad = False# 为了适应自己的数据集,将ResNet-50的最后一层替换为,将原来最后一个全连接层的输入喂给一个有256个输出单元的线性层,接着再连接ReLU层和Dropout层,然后是256 x 6的线性层,输出为6通道的softmax层。
fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)# 用GPU进行训练。
resnet50 = resnet50.to('cuda:0')# 定义损失函数和优化器。
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())# 五、训练
def train_and_valid(model, loss_function, optimizer, epochs=25):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")history = []best_acc = 0.0best_epoch = 0for epoch in range(epochs):epoch_start = time.time()print("Epoch: {}/{}".format(epoch+1, epochs))model.train()train_loss = 0.0train_acc = 0.0valid_loss = 0.0valid_acc = 0.0for i, (inputs, labels) in enumerate(tqdm(train_data)):inputs = inputs.to(device)labels = labels.to(device)#因为这里梯度是累加的,所以每次记得清零optimizer.zero_grad()outputs = model(inputs)loss = loss_function(outputs, labels)print("标签值:",labels)print("输出值:",outputs)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))train_acc += acc.item() * inputs.size(0)with torch.no_grad():model.eval()for j, (inputs, labels) in enumerate(tqdm(valid_data)):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = loss_function(outputs, labels)valid_loss += loss.item() * inputs.size(0)ret, predictions = torch.max(outputs.data, 1)correct_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor))valid_acc += acc.item() * inputs.size(0)avg_train_loss = train_loss/train_data_sizeavg_train_acc = train_acc/train_data_sizeavg_valid_loss = valid_loss/valid_data_sizeavg_valid_acc = valid_acc/valid_data_sizehistory.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])if best_acc < avg_valid_acc:best_acc = avg_valid_accbest_epoch = epoch + 1epoch_end = time.time()print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')return model, historynum_epochs = 100 #训练周期数
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

测试:图片名改下即可。

import torch
from torchvision import  models, transforms
import torch.nn as nn
import cv2
classes = ["1","2","3","4","5","6","7","8","9"] #识别种类名称(顺序要与训练时的数据导入编号顺序对应,可以使用datasets.ImageFolder().class_to_idx来查看)transf = transforms.ToTensor()
device = torch.device('cuda:0')
num_classes = 2
model_path = "models/skin-cancer-detection_model_3.pt"
image_input = cv2.imread("ISIC_0000019.jpg")
image_input = transf(image_input)
image_input = torch.unsqueeze(image_input,dim=0).cuda()
#搭建模型
resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():param.requires_grad = Falsefc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, num_classes),nn.LogSoftmax(dim=1)
)
resnet50 = torch.load(model_path)outputs = resnet50(image_input)
value,id =torch.max(outputs,1)
print(outputs,"\n","结果是:",classes[id])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/823400.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【滑动窗口】C++算法:K 个不同整数的子数组

作者推荐 动态规划 多源路径 字典树 LeetCode2977:转换字符串的最小成本 本题涉及知识点 滑动窗口 LeetCoe992 K 个不同整数的子数组 给定一个正整数数组 nums和一个整数 k&#xff0c;返回 nums 中 「好子数组」 的数目。 如果 nums 的某个子数组中不同整数的个数恰好为 …

【AI导师】利用Coding Agent完成AIGC编程

利用Coding Agent完成AIGC编程 一、前言二、Coding Agent三、1024code四、AI导师README项目初版功能定义代码结构设计方案函数方法设计方案迭代记录 一、前言 AI产品的发展确实在过去两年年中取得了显著进展&#xff0c;尤其是在编程领域。一开始&#xff0c;ChatGPT和类似的语…

前后端分离架构的特点以及优缺点

文章目录 一、前后端不分离架构(传统单体结构)1.1 什么是前后端不分离1.2 工作原理1.3 前后端不分离的优缺点1.4 应用场景 二、前后端分离架构2.1 为什么要前后端分离2.2 什么是前后端分离2.3 工作原理2.4 前后端分离的优缺点 参考资料 一、前后端不分离架构(传统单体结构) 首…

阿里后端实习二面

阿里后端实习二面 记录面试题目&#xff0c;希望可以帮助到大家 类加载的流程&#xff1f; 类加载分为三个部分&#xff1a;加载、连接、初始化 加载 类的加载主要的职责为将.class文件的二进制字节流读入内存(JDK1.7及之前为JVM内存&#xff0c;JDK1.8及之后为本地内存)&…

EBU7140 Security and Authentication(一)常见加密算法

前言 主要根据 EBU7140 课程内容整理&#xff0c;比较偏向应试~ Block1&#xff1a;介绍课程&#xff0c;传统加密方式。 Block2&#xff1a;公钥加密的原理和应用。 Block3&#xff1a;一些特定安全协议技术&#xff08;如防火墙 Kerberos身份验证协议等&#xff09;。 B…

【教学类-43-03】20231229 N宫格数独3.0(n=1、2、3、4、6、8、9) (ChatGPT AI对话大师生成 回溯算法)

作品展示&#xff1a; 背景需求&#xff1a; 大4班20号说&#xff1a;我不会做这种&#xff08;九宫格&#xff09;&#xff0c;我做的是小格子的&#xff0c; 他把手工纸翻过来&#xff0c;在反面自己画了矩阵格子。向我展示&#xff1a;“我会做这种&#xff01;” 原来他会…

《PCI Express体系结构导读》随记 —— 第I篇 第1章 PCI总线的基本知识(15)

接前一篇文章&#xff1a;《PCI Express体系结构导读》随记 —— 第I篇 第1章 PCI总线的基本知识&#xff08;14&#xff09; 1.3 PCI总线的存储器读写总线事务 1.3.4 PCI读写主存储器 前文已提到&#xff0c;由于本节内容较长&#xff0c;因此将后一部分内容放在本文中。 为…

Java多线程技术五——单例模式与多线程

1 概述 本章的知识点非常重要。在单例模式与多线程技术相结合的过程中&#xff0c;我们能发现很多以前从未考虑过的问题。这些不良的程序设计如果应用在商业项目中将会带来非常大的麻烦。本章的案例也充分说明&#xff0c;线程与某些技术相结合中&#xff0c;我们要考虑的事情会…

java注解和反射

java注解和反射 内置注解 Override 重写生命 Deprecated 已过时的方法&#xff0c;不推荐使用&#xff0c;可以使用 SuppressWarning 镇压警告&#xff0c;懂的都懂 元注解 作用&#xff1a;负责注解其他的注解 Target 描述注解的使用范围 Retention 描述注解的生命周期 Docu…

从座舱到跨域融合,老牌汽车零部件厂商如何破局数字化变革

当前&#xff0c;整个汽车供应链正在经历深层次的重构&#xff0c;传统零部件厂商必须加速“自我革新”。 在汽车“新四化”的巨变下&#xff0c;大量传统零部件濒临消失或者减少了需求&#xff0c;传统汽车零部件企业的相关业务开始日益萎缩&#xff0c;生存空间遭受不同程度…

我的128天之创作纪念日

目录 序 机缘 收获 日常 成就 憧憬 序 今天收到CSDN的一条消息推送&#xff0c;“初九之潜龙勿用 &#xff0c;不知不觉今天已经是你成为创作者的 第128天 啦。。。” 是啊&#xff0c;自今年8月24日开始写文章以来&#xff0c;时间过得好快&#xff0c;无论开心、痛苦…

51单片机之LED灯

51单片机之LED灯 &#x1f334;前言&#xff1a;&#x1f3ee;点亮LED灯的原理&#x1f498;点亮你的第一个LED灯&#x1f498;点亮你的八个LED灯 &#x1f4cc;让LED灯闪烁的原理&#x1f3bd; LED灯的闪烁&#x1f3d3;错误示范1&#x1f3d3;正确的LED闪烁代码应该是这样&am…

【开源】基于Vue+SpringBoot的公司货物订单管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 客户管理模块2.2 商品维护模块2.3 供应商管理模块2.4 订单管理模块 三、系统展示四、核心代码4.1 查询供应商信息4.2 新增商品信息4.3 查询客户信息4.4 新增订单信息4.5 添加跟进子订单 五、免责说明 一、摘要 1.1 项目…

4.26 构建onnx结构模型-Suqeeze

前言 构建onnx方式通常有两种&#xff1a; 1、通过代码转换成onnx结构&#xff0c;比如pytorch —> onnx 2、通过onnx 自定义结点&#xff0c;图&#xff0c;生成onnx结构 本文主要是简单学习和使用两种不同onnx结构&#xff0c; 下面以 Suqeeze 结点进行分析 方式 方法一…

three.js实现点击选中模型,模型描边高亮效果

射线投射器Raycaster通过.intersectObjects()判断模型是否选中EffectComposer.js进行后期处理&#xff0c;添加描边高亮效果 <template><div class"app"><div ref"canvesRef" class"canvas-wrap"></div></div> &…

Python面向对象编程 —— 类和异常处理

​ &#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 &#x1f4ab;个人格言:"没有罗马,那就自己创造罗马~" 目录 1. 类 1.1 类的定义 1.2 类变量和实例变量 1.3 类的继承 2. 异常处理 2.1类型异常 2.…

【docker实战】安装tomcat并连接mysql数据库

本节用docker来安装tomcat&#xff0c;并用这个tomcat连接我们上一节安装好的mysql数据库 一、拉取镜像 [rootlocalhost data]# docker pull tomcat:8.5.69二、运行tomcat bitnami的tomcat的根目录在/opt/bitnami/tomcat/webapps下面&#xff0c;所以我们为了方便部署我们的…

Springboot整合MybatisPlus的基本CRUD

目录 前言1. 搭建项目2. 基本的CRUD 前言 发现项目框架是MybatisPlus的&#xff0c;由于个人使用该框架的CRUD比较少 对此学习过程中&#xff0c;从零到有开始搭建学习还是比较重要的&#xff0c;感悟会比较多 关于各个类的使用&#xff0c;可看如下文章&#xff1a; 剖析Ja…

DotNet 命令行开发

DotNet 命令行开发 下载安装下载 SDK安装 SDK绿色版下载绿化脚本 常用命令创建 dotnet new运行 dotnet run发布应用 dotnet publish更多命令 VSCode 调试所需插件调试 CS 配置项目.csproj排除依赖关系 launch.jsontasks.json 参考资料 下载安装 下载 SDK 我们就下最新的好&am…

每日一题——LeetCode961

方法一 排序法&#xff1a; 2*n长度的数组里面有一个元素重复了n次&#xff0c;那么将数组排序&#xff0c;求出排序后数组的中间值&#xff08;因为长度是偶数&#xff0c;没有刚好的中间值&#xff0c;默认求的中间值是偏左边的那个&#xff09;那么共有三种情况&#xff1a;…