yolov5及yolov7实战之剪枝

之前有讲过一次yolov5的剪枝:yolov5实战之模型剪枝_yolov5模型剪枝-CSDN博客
当时基于的是比较老的yolov5版本,剪枝对整个训练代码的改动也比较多。最近发现一个比较好用的剪枝库,可以在不怎么改动原有训练代码的情况下,实现剪枝的操作,这篇文章就简单介绍一下,剪枝的概念以及为什么要剪枝可以参看上一篇,这里就不赘述了。

Torch-Pruning

VainF/Torch-Pruning: [CVPR 2023] Towards Any Structural Pruning; LLMs / Diffusion / Transformers / YOLOv8 / CNNs (github.com)
今天我们要用到的就是这个剪枝库,这个库集成了很多剪枝的方法,毕竟使用比较简单。

用法

这个剪枝库既有low level的剪枝,也就是手动控制剪枝哪些层,也有high level的剪枝,就是使用预设的剪枝算法,自动选择剪枝的部分。对于我们来说,更适合使用high level剪枝。具体的这里使用和上一篇yolov5里面的剪枝一样的算法,在这个库里叫BNScalePruner。

安装

首先我们需要安装上面提到的库,有两种方式来安装:

pip install torch-pruning

或源码安装(当碰到bug发布版本没修复,源码修复的时候):

pip install git+https://github.com/VainF/Torch-Pruning.git

稀疏化训练

为了更好的剪枝,我们在训练剪枝前的网络时,推荐开启稀疏化训练,利用这个库,我们可以很方便的实现这个操作。
首先在我们的训练代码中定义好剪枝器, 这里的opt.prune是我自己加的来控制是否开启稀疏化训练的标志:

# prune
if opt.prune:examle_input = torch.randn(1, 3, imgsz, imgsz).to(device)imp = tp.importance.BNScaleImportance()pruner = tp.pruner.BNScalePruner(model, examle_input, imp,reg=0.0001)

稀疏化训练主要需要设置reg参数,一般设置0.001~1e-6之间。
定义好剪枝器后,在训练代码的scaler.scale(loss).backward()之后,添加如下代码:

if opt.prune:pruner.regularize(model)

即可实现稀疏化训练。

剪枝

稀疏化训练后(也可以不做稀疏化训练),我们就可以进行剪枝操作了。这个库可以在训练中交互式进行多次剪枝,简单起见,我们这里分离剪枝和训练的代码,只进行剪枝操作。

import torch_pruning as tp
from models.experimental import attempt_load
import torchweights = "yolov7.pt"
model = attempt_load(weights, map_location=torch.device('cuda:0'), fuse=False)
for p in model.parameters():p.requires_grad = True
ignored_layers = []
from models.yolo import Detect, IDetect
from models.common import ImplicitA, ImplicitM
for m in model.modules():if isinstance(m, (Detect,IDetect)):ignored_layers.append(m.m)
unwrapped_parameters = []
for name, m in model.named_parameters():if isinstance(m, (ImplicitA,ImplicitM,)):unwrapped_parameters.append((name,1)) # pruning 1st dimension of implicit matrixprint(ignored_layers)
example_inputs = torch.rand(1, 3, 416, 416, device='cuda:0')
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(model, example_inputs, imp,ignored_layers=ignored_layers,unwrapped_parameters=unwrapped_parameters,global_pruning=True,ch_sparsity=0.3,round_to=8,)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_model = pruner.model
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"macs: {base_macs} -> {pruned_macs}")
print(f"nparams: {base_nparams} -> {pruned_nparams}")
macs_cutoff_ratio = (base_macs - pruned_macs) / base_macs
nparams_cutoff_ratio = (base_nparams - pruned_nparams) / base_nparams
print(f"macs cutoff ratio: {macs_cutoff_ratio}")
print(f"nparams cutoff ratio: {nparams_cutoff_ratio}")
save_path = weights.replace(".pt", "_pruned_bn_0.3.pt")torch.save({"model": pruned_model.module if hasattr(pruned_model, 'module') else pruned_model}, save_path)

去掉一些计算剪枝比例的,保存代码等代码外,剪枝操作其实由pruner.step()这一步完成。这里我们主要需要设置的参数是:

  • ch_sparsity: 可以理解成剪枝的比例,越大剪得越多
  • global_pruning: True表示整个模型的权重按一个整体排序后剪枝,False表示按分组内部按比例剪枝
  • round_to: 剪枝后的通道保留为多少的倍数,一般在硬件上,保留8的倍数

微调

经过剪枝的网络,精度是下降比较明显的,需要再在数据上finetune一些epoch才能把精度拉回来。
yolov7默认是通过yaml文件创建模型结构,然后再载入权重进行训练的,而我们剪枝后的模型是没有模型结构文件的,因此需要对训练代码做一定的修改,具体而言,只是对模型的载入进行一点修改。其中opt.finetune是用来控制是否处于finetune模式的标志位。

if opt.finetune: # for model without cfgnew = torch.load(weights, map_location=device)  # createmodel = new["model"]print("Finetune Mode...")
elif pretrained:
...

比较简单的改法是这样,从checkpoint中载入结构和权重,还有一种方式则是修改yolov7的Model类,这个在后面讲yolov7剪枝后蒸馏的时候再讲,暂时用上面这种方式就可以了。

评测

我在自己的任务上的效果是yolov7剪枝50%,微调后基本上能达到剪枝前的map,没记错的话这是和稀疏化训练的比,毕竟开启稀疏化训练本身也会掉点。大家可以在自己的任务上尝试一下,总体上精度还是可以的

结语

这篇文章简述了以下yolov7的剪枝,yolov5也可用,希望对大家有帮助。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/143072.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

使用自定义注解发布webservice服务

使用自定义注解发布webservice服务 概要代码自定义注解WebService接口服务发布配置使用 结果 概要 在springboot使用webservice,发布webservice服务的时候,我们经常需要手动在添加一些发布的代码,比如: Bean public Endpoint or…

破信息壁垒,亿发一站式ERP系统建设,打造五金制造信息管理平台

五金制造拥有明显的行业特征,如体量小、品种繁多、颜色多样、加工工艺不断演进等,呈现出一种独特的管理挑战。大多数五金企业仍然依赖人工管理和经验决策,如今需要寻求更合理和科学的决策方法,以实现生产、销售、仓储、采购和财务…

百度SEO优化技巧(选择、网站结构、内容优化、外链建设、数据分析)

百度关键词SEO优化介绍 SEO是搜索引擎优化的缩写,是指通过优化网站结构、内容和外部链接等方式,提高网站在搜索引擎中的排名,从而获取更多的访问量和流量。百度是中国最大的搜索引擎之一,对于企业来说,优化百度关键词…

uniapp 事件委托失败 获取不到dataset

问题&#xff1a; v-for 多个span ,绑定点击事件 代码:view里包着一个span, <view class"status-list" tap"search"><span class"status-item" v-for"(key,index) in statusList" :key"index" :data-key"k…

USB转换方案介绍

随着科技的不断发展&#xff0c;我们的生活中出现了越来越多的电子设备。然而&#xff0c;这些设备通常具有不同的连接端口和协议&#xff0c;这可能会使它们之间的连接变得困难。这时候&#xff0c;使用USB转换就成为了一种非常方便和实用的解决方法。 无论是在家庭、办公室还…

系统集成|第十章(笔记)

目录 第十章 质量管理10.1 项目质量管理概论10.2 主要过程10.2.1 规划质量管理10.2.2 实施质量保证10.2.3 质量控制 10.3 常见问题 上篇&#xff1a;第九章、成本管理 下篇&#xff1a;第十一章、人力资源管理 第十章 质量管理 10.1 项目质量管理概论 质量管理&#xff1a;指确…

创建型设计模式——工厂模式

摘要 本博文主要介绍软件设计模式中工厂模式&#xff0c;其中工厂设计模式的扩展为简单工厂(Simple Factory)、工厂方法(Factory Method)、抽象工厂(Abstract Factory)三种。 一、简单工厂(Simple Factory) 主要分析设计模式 - 简单工厂(Simple Factory)&#xff0c;它把实例…

PHP8的类与对象的基本操作之类的实例化-PHP8知识详解

定义完类和方法后&#xff0c;并不是真正创建一个对象。类和对象可以描述为如下关系。类用来描述具有相同数据结构和特征的“一组对象”&#xff0c;“类”是“对象”的抽象&#xff0c;而“对象”是“类”的具体实例&#xff0c;即一个类中的对象具有相同的“型”&#xff0c;…

【DETR】

https://tianfeng.space/ 前言 论文 代码 DETR&#xff08;Data-efficient Image Transformer&#xff09;是一种用于目标检测任务的深度学习模型。它与传统的目标检测方法不同&#xff0c;采用了Transformer架构&#xff0c;将目标检测问题转化为一个序列到序列的问题。以下…

Java之IO流概述

1.1 什么是IO 生活中&#xff0c;你肯定经历过这样的场景。当你编辑一个文本文件&#xff0c;忘记了ctrls &#xff0c;可能文件就白白编辑了。当你电脑上插入一个U盘&#xff0c;可以把一个视频&#xff0c;拷贝到你的电脑硬盘里。那么数据都是在哪些设备上的呢&#xff1f;键…

【数据库——MySQL】(10)视图和索引

目录 1. 视图1.1 创建视图1.2 查询视图 2. 索引2.1 索引的分类2.2 索引的建立 参考书籍 1. 视图 1.1 创建视图 基础语法&#xff1a; CREATE [OR REPLACE] VIEW 视图名[(列名表)]ASSELECT语句[WITH CHECK OPTION]说明&#xff1a; 在默认情况下&#xff0c;将在当前数据库创…

Linux 用户 用户组管理

用户 Linux系统是一个多用户多任务的分时操作系统&#xff0c;任何要使用系统资源的用户&#xff0c;都必须首先向系统管理员申请一个账号&#xff0c;然后以这个账号的身份进入系统。每个用户账号都拥有一个唯一的用户名和各自的口令。用户在登录时键入正确的用户名和口令后&a…

华为ICT——第二章-数字图像处理私人笔记

目录 1&#xff1a;计算机视觉&#xff1a;​编辑 2&#xff1a;计算机视觉应用&#xff1a;​编辑 3&#xff1a;计算机视界核心问题&#xff1a;​编辑 4&#xff1a;相关学科&#xff1a; 5&#xff1a;计算机视觉与人工智能&#xff1a; 最成熟的技术方向是图像识别 6…

【面试算法——动态规划 20】最长公共子序列 不相交的线

1143. 最长公共子序列 链接: 1143. 最长公共子序列 给定两个字符串 text1 和 text2&#xff0c;返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff0c;返回 0 。 一个字符串的 子序列 是指这样一个新的字符串&#xff1a;它是由原字符串在不改变字…

Spring面试题8:面试官:说一说Spring的BeanFactory

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:说一说Spring的BeanFactory Spring的BeanFactory是Spring框架的核心容器,负责管理和创建Bean对象。它是一个工厂类,用于实例化、配置和管理Bean的…

SpringBoot 如何使用 Druid 进行数据库连接池管理

使用 Druid 进行数据库连接池管理的 Spring Boot 应用 数据库连接池是任何Web应用程序的重要组成部分&#xff0c;它们有助于管理数据库连接的复用&#xff0c;提高性能和资源利用率。Druid是一个强大的数据库连接池&#xff0c;它具有监控、防SQL注入、快速、可扩展等特点。在…

谈谈最近招人的感受!

最近折腾新的项目&#xff0c;面试了很多实习生小伙伴&#xff0c;我说说我的一些「面试」感受&#xff0c; 虽然是一个老生常谈的话题&#xff0c;但是依然提一下。 准时很重要&#xff1a;提前一点时间&#xff0c;踩个点&#xff0c;别迟到&#xff0c;面试的过程中由于每个…

低功耗引擎Cliptrix为什么可以成为IOT的高效能工具

在万物互联的时代&#xff0c;现代人已普遍接受电视、音箱等电器设备具备智能化能力&#xff0c;也是在这个趋势下&#xff0c;我们身边越来越多的iOT设备联网和交互成为刚需。 但iot设备也面临到一些非常显著的痛点&#xff0c;例如iot设备的内存、处理器等核心元件无法与手机…

爬虫 — 多线程

目录 一、多任务概念二、实现多任务方式1、多进程 &#xff08;Multiprocessing&#xff09;2、多线程&#xff08;Multithreading&#xff09;3、协程&#xff08;Coroutine&#xff09; 三、多线程执行顺序四、多线程的方法1、join()2、setDaemon()3、threading.enumerate() …

python运算函数

简 python输入输出函数input() :用户用于读取键盘输入的函数&#xff0c;返回值为“string”类型 运算函数abs(x) &#xff1a;x的绝对值int(x) &#xff1a;将x转换成整型(截掉小数部分)float(x):浮点数divmod(x,y):返回&#xff08;x//y,x%y&#xff09;complex(re,im):返回一…