基于迁移学习的手势分类模型训练

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(),  # 随机水平翻转RandomRotation(10),      # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1)  # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1488880.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

一文掌握YOLOv1-v10

引言 YOLO目标检测算法,不过多介绍,是基于深度学习的目标检测算法中最出名、发展最好的检测器,没有之一。本文简要的介绍一下从YOLOv1-YOLOv10的演化过程,详细技术细节不过多介绍,只提及改进点,适合初学者…

Vue3二次封装axios

官网: https://www.axios-http.cn/docs/interceptors steps1: 安装 npm install axios -ssteps2: /src/api/request.js 文件 >>> 拦截器 import axios from axios // 如果没用element-plus就不引入 import { ElMessage } from element-plusconst service axios.cre…

7月22日学习笔记 文件共享服务nfs,SAMBA文件共享与DNS域名服务

任务背景 由于业务驱动,为了提⾼⽤户的访问效率,现需要将原有web服务器上的静态资源 ⽂件分离出来,单独保存到⼀台⽂件服务器上。 任务要求 1. ⼀台应⽤服务器web-server部署apache,静态⽹⻚资源存放在另外⼀台NFS服 务器上 …

四、GD32 MCU 常见外设介绍 (2) GPIO 模块介绍

2.GPIO 模块介绍 GPIO的全称为通用输入输出口,是很多外设能够正常工作的必要条件。除了一些特定功能的引脚(如电源脚)外,MCU上其他的引脚都可以当做GPIO来使用。本章,我们将对GPIO进行简单介绍,并通过一个“流水灯”的实验来熟悉…

MATLAB基础:数组及其数学运算

今天我们继续学习MATLAB中的数组 我们在学习MATLAB时了解到,MATLAB作者秉持着“万物皆可矩阵”的思想企图将数学甚至世间万物使用矩阵表示出来,而矩阵的处理,自然成了这门语言的重中之重。 数组基础 在MATLAB中,数组是一个基本…

【人工智能 | 机器学习 | 理论篇】线性模型

文章目录 1. 基本形式2. 线性回归3. 对数几率回归4. 线性判别分析5. 多分类学习6. 类别不平衡问题 1. 基本形式 设有 d 个属性描述的示例 x ( x 1 , x 2 , x 3 , . . . , x d ) x ({x_1, x_2, x_3, ..., x_d}) x(x1​,x2​,x3​,...,xd​) 线性模型(linear mode…

使用C#手搓Word插件

WordTools主要功能介绍 编码语言:C#【VSTO】 1、选择 1.1、表格 作用:全选文档中的表格; 1.2、表头 作用:全选文档所有表格的表头【第一行】; 1.3、表正文 全选文档中所有表格的除表头部分【除第一行部分】 1.…

Android AI应用开发:移动检测

基于Google ML模型的Android移动物体检测应用——检测、跟踪视频中的物体 A. 项目描述 ML Kit物体检测器可以对视频流进行操作,能够检测视频中的物体并在连续视频帧中跟踪该物体。 相机捕捉视频时,检测到移动物体并为其生成一个边界框,并分…

【性能测试-登录时密码加密存储如何传参】

目的】 登录接口,密码加密传输,开发不做处理的情况下,密码如何加密传输 【方案】 使用前置处理器:JSR223 预处理程序,主要是在执行登录接口前将密码按照加密算法获得对应的加密密码,并传入接口 【说明】前…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 亲子游戏(200分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,支持题目在线…

【BUG】已解决:TypeError: a bytes-like object is required, not ‘str‘

TypeError: a bytes-like object is required, not ‘str‘ 目录 TypeError: a bytes-like object is required, not ‘str‘ 【常见模块错误】 【解决方案】 错误原因分析 解决方案 示例代码 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998https://bbs.csdn.net…

基于扩散的生成模型的语音增强和去噪

第二章 目标说话人提取之《Speech Enhancement and Dereverberation with Diffusion-based Generative Models》 文章目录 前言一、任务二、动机三、挑战四、方法1.方法:基于分数的语音增强生成模型(sgmse)2.网络结构 五、实验评价1.数据集2.采样器设置和评价指标3.基线模型4.评…

PaliGemma:A versatile 3B VLM for transfer

1.model 1.1 Architecture 图像分辨率为固定的正方形,224,448,896,这导致每种模型都有固定数量的图像token,256,1024,4096。图像在最前面,无需特殊的位置标记,BOS标记文本的开始,\n作为SEP token,不出现在前缀中,单独对SEP进行标记,以避免它与前缀的结束或后缀的…

力扣94题(java语言)

题目 思路 使用一个栈来模拟递归的过程,以非递归的方式完成中序遍历(使用栈可以避免递归调用的空间消耗)。 遍历顺序步骤: 遍历左子树访问根节点遍历右子树 package algorithm_leetcode;import java.util.ArrayList; import java.util.List; import…

立仪光谱共焦传感器应用测量之:汽车连接器高度差测量

01 检测要求,要求测量汽车连接器的高度差 02 检测方式 根据观察,我们采用立仪科技光谱共焦H4UC控制器搭配D65A52系列镜头,角度最大,外径最大,量程大,可以有效应用于测量弧面,大角度面等零件。 0…

SAPUI5基础知识19 - 视图嵌套(Nested Views)

1. 背景 SAPUI5 是一个用于构建企业级 Web 应用程序的 JavaScript 框架。它提供了丰富的 UI 控件和工具,帮助开发者创建复杂的用户界面。Nested Views 是 SAPUI5 中的一种设计模式,允许在一个视图中嵌套另一个视图。这种模式有助于模块化和重用代码&…

什么是反射以及反射的应用及例子

反射是Java中框架设计的核心,通过对类的构造、属性、方法等数据的获取提供抽象的底层构建。 反射机制: 反射需要先获得类的class字节码,由JVM类加载器(ClassLoader)负责加载,并在内存中缓存class的内部结构。借助于Java的反射机制…

面试常考Linux指令

文件权限 操作系统中每个文件都拥有特定的权限、所属用户和所属组。权限是操作系统用来限制资源访问的机制,在 Linux 中权限一般分为读(readable)、写(writable)和执行(executable),分为三组。分别对应文件的属主(owner),属组(group)和其他用…

腾讯智影PC端“智能画布”功能上线

随着人工智能技术的不断发展,图片编辑领域也迎来了创新的变革。腾讯智影PC端近日推出了一项革命性的新功能——“智能画布”,它将AI绘画技术与传统图片编辑相结合,为用户带来了前所未有的便捷体验。 腾讯智影官网 地址:点击此处…

【Sentinel】Sentinel超简单入门,一看就懂!!!

sentinel入门 一、什么是Sentinel1.1、流量控制1.2、熔断降级1.3、热点参数限流1.4、系统负载保护 二、资源和规则的概念三、Sentinel工作主流程四、代码初体验五、集成控制台 一、什么是Sentinel Sentinel 是阿里巴巴开源的流量控制组件,主要用来保护微服务和分布…