【使用resnet18训练自己的数据集】

1.背景及准备

书接上文【以图搜图代码实现】–犬类以图搜图示例 总结了一下可以优化的点,其中提到使用自己的数据集训练网络,而不是单纯使用预训练的模型,这不就来了!!

使用11类犬类微调resnet18网络模型:
1. 数据准备
数据集】11种犬类,共1089张
链接:百度网盘链接
提取码:qlrt
在这里插入图片描述
2. 数据集划分
按照train和val8:2的比例进行划分,划分代码如下:

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :split_data.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30
'''
import os
import shutil
import randomdef split_images_into_train_test(source_directory, train_directory, val_directory, train_ratio=0.8):"""将源文件夹下的图片按照指定比例分为训练集和测试集,并分别复制到train和val文件夹下。"""# 确保train和test目录存在,如果不存在则创建os.makedirs(train_directory, exist_ok=True)os.makedirs(val_directory, exist_ok=True)# 获取源文件夹中所有图片文件的列表image_files = [f for f in os.listdir(source_directory) if os.path.isfile(os.path.join(source_directory, f))]image_files = [f for f in image_files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]# 打乱图片文件列表的顺序random.shuffle(image_files)total_images = len(image_files)train_images_count = int(train_ratio * total_images)# 将图片复制到对应的文件夹下for i, image_file in enumerate(image_files):source_path = os.path.join(source_directory, image_file)if i < train_images_count:dest_path = os.path.join(train_directory, image_file)else:dest_path = os.path.join(val_directory, image_file)shutil.copy2(source_path, dest_path)  print(f"Copied {image_file} to {os.path.dirname(dest_path)}")if __name__ == '__main__':source_directory = "E:\\xxx\\datas\\imgs"train_directory = "E:\\xxx\\datas\\pet_dog\\train"val_directory = "E:\\xxx\\datas\\pet_dog\\val"file_list = os.listdir(source_directory)for file in file_list:source=os.path.join(source_directory, file)val = os.path.join(val_directory, file)train = os.path.join(train_directory, file)split_images_into_train_test(source, train, val)

最终效果:
在这里插入图片描述
train和val下的目录结果都是如下图所示,只是数量不一样。
在这里插入图片描述

2.代码实现

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :train.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30  
'''
import torch
import os
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import argparsedef train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=5):"""参数:model: torch.nn.Module - 要训练的模型实例。dataloaders: dict - 包含训练集和验证集的数据加载器,例如{'train': train_loader, 'val': val_loader}。criterion: nn.Module - 用于计算损失的函数。optimizer: torch.optim.Optimizer - 用于更新模型参数的优化器。scheduler: torch.optim.lr_scheduler._LRScheduler - 学习率调度器。num_epochs: int - 训练的总轮数。"""device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)for epoch in range(num_epochs):print(f"Epoch {epoch + 1}/{num_epochs}")for phase in ['train', 'val']:if phase == 'train':model.train()  # 设置模型为训练模式else:model.eval()   # 设置模型为评估模式running_loss = 0.0running_corrects = 0i = 0for inputs, labels in dataloaders[phase]:i+=1inputs, labels = inputs.to(device), labels.to(device)# 前向传播with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)loss = criterion(outputs, labels)if i%10==0:print(f"{phase} Loss: {loss:.4f}")_, preds = torch.max(outputs, 1)# 反向传播与优化(仅在训练阶段)if phase == 'train':optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / len(dataloaders[phase].dataset)epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")print("Training complete.")# 训练完成后,保存模型状态字典(包含权重)torch.save(model.state_dict(), './weights/resnet18_dog.pth')def main():# 创建参数解析器parser = argparse.ArgumentParser(description='使用自己的数据集训练resnet18')# 添加参数parser.add_argument('--data_dir', type=str, default="E:\HWR_files\datas\pet_dog",help='Path to the dataset directory')parser.add_argument('--batch_size', type=int, default=16, help='Input batch size for training (default: 16)')parser.add_argument('--num_workers', type=int, default=2, help='Number of workers for data loading (default: 2)')parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (default: 0.001)')parser.add_argument('--num_epochs', type=int, default=5, help='Number of epochs to train (default: 25)')args = parser.parse_args()# 数据预处理和加载data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) for x in['train', 'val']}dataloaders_dict = {x: DataLoader(image_datasets[x], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) for xin ['train', 'val']}# 使用ResNet18模型model = models.resnet18(pretrained=True)# 加载之前保存的权重# model.load_state_dict(torch.load('./weights/resnet18_dog.pth'))num_features = model.fc.in_featuresmodel.fc = torch.nn.Linear(num_features, len(image_datasets['train'].classes))  # 修改最后一层全连接层device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = model.to(device)# 定义损失函数和优化器criterion = torch.nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 训练模型train_model(model, dataloaders_dict, criterion, optimizer, scheduler, args.num_epochs)if __name__ == '__main__':main()

在这里插入图片描述

3.代码测试

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :test.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30 
'''
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, modelsdef test_model(weights_path, val_root, batch_size=4):"""使用验证集测试模型性能。参数:- weights_path: str, 训练好的模型权重文件路径- val_root: str, 验证数据集的根目录- batch_size: int, 数据加载时的批次大小"""# 设定数据预处理transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加载验证集val_dataset = datasets.ImageFolder(root=val_root, transform=transform)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)# 初始化模型并加载权重device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 修改最后一层全连接层确保num_classes与实际类别数匹配model = models.resnet18()num_features = model.fc.in_featuresmodel.fc = torch.nn.Linear(num_features, len(val_dataset.classes))model.load_state_dict(torch.load(weights_path, map_location=device))model.to(device)model.eval()  # 设置模型为评估模式# 测试循环correct = 0total = 0with torch.no_grad():for images, labels in val_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 计算准确率并打印结果accuracy = 100 * correct / totalprint(f'Accuracy on validation set: {accuracy}%')if __name__ == '__main__':weights_path = './weights/resnet18_dog.pth'val_root = "E:\\xxx\\datas\\pet_dog\\val"test_model(weights_path, val_root)

结果图:
在这里插入图片描述
4.效果对比

书接上篇的图像检索:【以图搜图代码实现】–犬类以图搜图示例
来看看有没有准一点的

使用预训练的resnet18:
在这里插入图片描述
离谱了,匹配的前三个都是吉娃娃

看看使用微调之后的resnet18:
对应在上一篇种,模型加载和最后一层的输出个数变成类别数,这里是11。

在这里插入图片描述
哇哇哇!效果显著呀!!!

可以可以,下次尝试使用faiss喽

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1551327.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

如何构建一个生产级的AI平台(1)?

本文概述了生成式 AI 平台的常见组件、它们的作用以及它们的实现方式。 本文重点介绍部署 AI 应用程序的整体架构。 它讨论了需要哪些组件以及构建这些组件时的注意事项。 它不是关于如何构建 AI 应用程序。 这就是整体架构的样子。 这是一个相当复杂的系统。 这篇文章将从最…

css 中 ~ 符号、text-indent、ellipsis、ellipsis-2、text-overflow: ellipsis、::before的使用

1、~的使用直接看代码 <script setup> </script><template><div class"container"><p><a href"javascript:;">纪检委</a><a href"javascript:;">中介为</a><a href"javascript:…

SpringBoot技术栈:打造下一代网上租赁系统

第2章 关键技术简介 2.1 Java技术 Java是一种非常常用的编程语言&#xff0c;在全球编程语言排行版上总是前三。在方兴未艾的计算机技术发展历程中&#xff0c;Java的身影无处不在&#xff0c;并且拥有旺盛的生命力。Java的跨平台能力十分强大&#xff0c;只需一次编译&#xf…

传统操作系统和分布式操作系统的区别

分布式操作系统和传统操作系统之间的区别&#xff0c;根植于它们各自的设计哲学和目标。要理解这些差异&#xff0c;需要从操作系统的基本定义、结构、功能以及它们在不同计算环境中的表现进行分析。每种系统都试图解决特定的计算挑战&#xff0c;因此在不同的使用场景下具有各…

基于springboot+vue的社区流浪动物救助系统

摘要 本文介绍了一个基于Spring Boot和Vue.js技术的社区流浪动物救助系统。该系统采用前后端分离架构&#xff0c;后端使用Spring Boot框架进行开发&#xff0c;负责业务逻辑的处理和数据的交互&#xff1b;前端则使用Vue.js框架&#xff0c;为用户提供友好的交互界面。系统实现…

Springboot学习笔记(4)MybatisPlus

1. MybatisPlus 1.1 ORM介绍 ORM&#xff08;Object Relational Mapping&#xff0c;对象关系映射&#xff09;是为了解决面向对象与关系数据库存在的互不匹配现象的一种技术。 比如&#xff0c;将java中的对象传递到关系型数据库中去&#xff0c;或者将关系型数据库传递到jav…

HarmonyOS Next系列之水波纹动画特效实现(十三)

系列文章目录 HarmonyOS Next 系列之省市区弹窗选择器实现&#xff08;一&#xff09; HarmonyOS Next 系列之验证码输入组件实现&#xff08;二&#xff09; HarmonyOS Next 系列之底部标签栏TabBar实现&#xff08;三&#xff09; HarmonyOS Next 系列之HTTP请求封装和Token…

Webpack 打包后文件过大,如何优化?

聚沙成塔每天进步一点点 本文回顾 ⭐ 专栏简介Webpack 打包后文件过大&#xff0c;如何优化&#xff1f;1. 代码分割&#xff08;Code Splitting&#xff09;1.1 概念1.2 Webpack 的 SplitChunksPlugin示例配置&#xff1a; 1.3 按需加载&#xff08;Lazy Loading&#xff09;示…

【无人机设计与技术】四旋翼无人机的建模

摘要 本项目的目标是通过 Simulink 建模和仿真&#xff0c;研究四旋翼无人机的建模、姿态控制、定点位置控制及航点规划功能。无人机建模包含了动力单元模型、控制效率模型和刚体模型&#xff0c;并运用这些模型实现了姿态控制和位置控制。姿态控制为无人机的平稳飞行提供基础…

字体文件压缩

技术点 npm、html、font-spider 实现原理 个人理解&#xff1a;先引入原先字体&#xff0c;然后重置字符为空&#xff0c;根据你自己填充文字、字符等重新生成字体文件&#xff0c;因此在引入的时候务必添加自己使用的文字、字符等&#xff01;&#xff01;&#xff01; 实…

高校体育场小程序|高校体育场管理系统系统|体育场管理系统小程序设计与实现(源码+数据库+文档)

高校体育场管理系统小程序 目录 体育场管理系统小程序设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕设布道…

ClickHouse入库时间与实际相差8小时问题

原因一&#xff1a;服务端未修改默认时区 解决方案&#xff1a; 1、找 ClickHouse 配置文件 config.xml&#xff0c;通常位于 /etc/clickhouse-server/ 目录。 2、编辑 config.xml 文件&#xff0c;找到 <timezone> 标签。如果标签不存在&#xff0c;需要手动添…

unity一键注释日志和反注释日志

开发背景&#xff1a;游戏中日志也是很大的开销&#xff0c;虽然有些日志不打印但是毕竟有字符串的开销&#xff0c;甚至有字符串拼接的开销&#xff0c;有些还有装箱和拆箱的开销&#xff0c;比如Debug.Log(1) 这种 因此需要注释掉&#xff0c;当然还需要提供反注释的功能&am…

避免学术欺诈!在ChatGPT帮助下实现严格引用并避免抄袭

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 当今的学术环境中&#xff0c;保持学术诚信至关重要。随着ChatGPT等技术的发展&#xff0c;写作变得更加高效&#xff0c;但也增加了不当使用的风险。严格的引用和避免抄袭不仅是学术道…

C++基础---类和对象(上)

1.类的定义 C程序设计允许程序员使用类&#xff08;class&#xff09;定义特定程序中的数据类型。这些数据类型的实例被称为对象 &#xff0c;这些实例可以包含程序员定义的成员变量、常量、成员函数&#xff0c;以及重载的运算符。语法上&#xff0c;类似C中结构体&#xff0…

Jmeter常用函数、逻辑控制器

目录 一、Jmeter常用函数 counter函数 machineName函数 machineIP函数 Random函数 RandomString函数 RandomDate函数 time函数 二、逻辑控制器 IF控制器 循环控制器 foreach控制器 仅一次控制器 事务控制器 聚合报告 随机控制器 随机顺序控制器 一、Jmeter常用…

趣味运动会分组记分指南

团队比赛时如何记分&#xff1f; 趣味运动会的组织过程中&#xff0c;分组和记分是两个关键环节。云分组小程序提供了一个高效的解决方案&#xff0c;无论是随机分组还是内定分组&#xff0c;都能轻松实现。系统还能自动统计积分和排名&#xff0c;极大简化了组织者的工作。 分…

如何在Python中计算移动平均值?

在这篇文章中&#xff0c;我们将看到如何在Python中计算移动平均值。移动平均是指总观测值集合中固定大小子集的一系列平均值。它也被称为滚动平均。 考虑n个观测值的集合&#xff0c;k是用于确定任何时间t的平均值的窗口的大小。然后&#xff0c;移动平均列表通过最初取当前窗…

Android Studio | 无法识别Icons.Default.Spa中的Spa

编写底部导航栏&#xff0c;涉及到Spa部分出现报红&#xff1a; 解决办法&#xff1a;在build.gradle.kts中引入图标依赖 dependencies {implementation "androidx.compose.material:material-icons-extended:<version>" }

Linux相关概念和重要知识点(9)(父进程、子进程、进程状态)

1.父进程、子进程 &#xff08;1&#xff09;父进程 CLI本质上是一款命令行界面的软件&#xff0c;是用户调用接口层面的程序&#xff08;上层&#xff0c;可以和系统调用接口做沟通&#xff09;&#xff0c;CLI和GUI是同级别的。用户的操作都是建立在CLI和GUI之上的。 但是…