当前位置: 首页 > news >正文

基于PyTorch的图像分类特征提取与模型训练文档

概述

本代码实现了一个基于PyTorch的图像特征提取与分类模型训练流程。核心功能包括:

  1. 使用预训练ResNet18模型进行图像特征提取

  2. 将提取的特征保存为标准化格式

  3. 基于提取的特征训练分类模型

代码结构详解 

1. 库导入

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
import numpy as np
import os
from ml.model_trainer import ModelTrainer
  • 关键库说明

    • torch:PyTorch核心库

    • torch.nn:神经网络模块

    • torchvision:计算机视觉专用模块

    • numpy:数值计算库

    • os:文件系统操作

    • ModelTrainer:自定义模型训练类(需另行实现)

2. 特征提取器类(FeatureExtractor)

初始化方法 __init__
def __init__(self):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model = torchvision.models.resnet18(weights='IMAGENET1K_V1')self.model = nn.Sequential(*list(self.model.children())[:-1])self.model = self.model.to(self.device).eval()self.transform = transforms.Compose([...])
  • 功能说明

    • 设备检测:自动选择GPU/CPU

    • 模型加载:使用ImageNet预训练的ResNet18

    • 模型修改:移除最后的全连接层(保留卷积特征提取器)

    • 预处理设置:标准化图像尺寸和颜色空间

特征提取方法 extract_features
def extract_features(self, data_dir):full_dataset = datasets.ImageFolder(...)loader = DataLoader(...)features = []labels = []with torch.no_grad():for inputs, targets in loader:inputs = inputs.to(self.device)outputs = self.model(inputs)features.append(outputs.squeeze().cpu().numpy())labels.append(targets.numpy())features = np.concatenate(...)labels = np.concatenate(...)return features, labels, full_dataset.classes
  • 关键参数

    • data_dir:包含分类子目录的图像数据集路径

    • batch_size=32:平衡内存使用与处理效率

    • num_workers=4:多线程数据加载

  • 处理流程

    1. 创建ImageFolder数据集

    2. 使用DataLoader批量加载

    3. 禁用梯度计算加速推理

    4. 特征维度压缩(squeeze)

    5. 设备间数据传输(GPU->CPU)

    6. 合并所有批次数据

3. 主执行流程

参数配置
DATA_DIR = "/home/.../data"  # 实际数据路径
SAVE_PATH = "./features.npz"  # 特征保存路径
特征提取与保存 
extractor = FeatureExtractor()
if not os.path.exists(SAVE_PATH):features, labels, classes = extractor.extract_features(DATA_DIR)np.savez(SAVE_PATH, features=features, labels=labels, classes=classes)
else:data = np.load(SAVE_PATH)features = data['features']labels = data['labels']
  • 文件结构

    • features: [N_samples, 512] 的特征矩阵

    • labels: [N_samples] 的标签数组

    • classes: 类别名称列表

模型训练与保存
X, y = features, labels
trainer = ModelTrainer()
model = trainer.train_model(X, y)
joblib.dump(model, 'pest_classifier.pkl')

 

  • 假设条件

    • ModelTrainer需实现训练逻辑(如SVM、随机森林等)

    • 默认使用全部数据进行训练(建议实际添加数据分割)

技术细节说明

1. 图像预处理流程

2. 特征维度分析

  • ResNet18最后层输出:512维特征向量

  • 假设1000张图像:

    • 原始图像:1000×3×224×224 (约150MB)

    • 提取特征:1000×512 (约2MB) → 显著降维

3. 性能优化策略

  • GPU加速:自动检测CUDA设备

  • 批量处理:32张/批平衡效率与内存

  • 缓存机制:避免重复特征提取

  • 梯度禁用:减少内存消耗

 

 

 

 

 

http://www.xdnf.cn/news/216973.html

相关文章:

  • 集群系统的五大核心挑战与困境解析
  • EtherCAT转CANopen方案落地:推动运动控制器与传感器通讯的工程化实践
  • CKESC Breeze 6S 40A_4S 50A FOC BEC电调测评:全新vfast 技术赋能高效精准控制
  • 低代码平台部署方案解析:百特搭四大部署方式
  • 大模型推理:Qwen3 32B vLLM Docker本地部署
  • 强化学习贝尔曼方程推导
  • 流量守门员:接口限流艺术
  • Manus AI多语言手写识别技术全解析:从模型架构到实战部署
  • JavaScript 中深拷贝浅拷贝的区别?如何实现一个深拷贝?
  • 信雅达 AI + 悦数 Graph RAG | 大模型知识管理平台在金融行业的实践
  • C# 类的基本概念(实例成员)
  • 【2024-NIPS-版权】Evaluating Copyright Takedown Methods for Language Models
  • 《云原生》核心内容梳理和分阶段学习计划
  • Alibaba第四版JDK源码学习笔记2025首次开源
  • HCIP【VLAN技术(详解)】
  • Java高频面试之并发编程-11
  • 第三部分:赋予网页灵魂 —— JavaScript(下)
  • Spring Boot - 配置管理与自动化配置进阶
  • 【Bash】可以请您解释性地说明一下“2>1”这个语法吗?
  • Windows 系统下使用 Docker 搭建Redis 集群(6 节点,带密码)
  • C++日更八股--first
  • SpringBoot应用:Docker与Kubernetes全栈实战秘籍
  • git fetch和git pull的区别
  • 域对齐是什么
  • 判断用户选择的Excel单元格区域是否跨页?
  • 力扣hot100——239.滑动窗口最大值
  • 在大数据环境下,使用spingboot为Android APP推送数据方案
  • 【Machine Learning Q and AI 读书笔记】- 02 自监督学习
  • 主流微前端框架比较
  • java面试题目