YOLOv10改进策略【注意力机制篇】| CVPR2024 CAA上下文锚点注意力机制

一、本文介绍

本文记录的是基于CAA注意力模块的YOLOv10目标检测改进方法研究在远程遥感图像或其他大尺度变化的图像中目标检测任务中,为准确提取其长距离上下文信息,需要解决大目标尺度变化和多样上下文信息时的不足的问题CAA能够有效捕捉长距离依赖,并且参数量和计算量更少。

文章目录

  • 一、本文介绍
  • 二、CAA原理
    • 2.1 原理
    • 2.2 优势
  • 三、CAA的实现代码
  • 四、创新模块
    • 4.1 改进点⭐
  • 五、添加步骤
    • 5.1 修改ultralytics/nn/modules/block.py
    • 5.2 修改ultralytics/nn/modules/__init__.py
    • 5.3 修改ultralytics/nn/modules/tasks.py
  • 六、yaml模型文件
    • 6.1 模型改进版本⭐
  • 六、成功运行结果


二、CAA原理

Poly Kernel Inception Network for Remote Sensing Detection

CAA(Context Anchor Attention)注意力的设计原理和优势如下:

2.1 原理

  • 采用平均池化1×1卷积来获取局部区域特征:对输入特征进行平均池化,然后通过1×1卷积得到局部区域特征。
  • 使用深度可分离的条形卷积来近似标准大核深度可分离卷积:通过两个深度可分离的条形卷积来扩大感受野,并且这种设计基于两个考虑。首先,条形卷积是轻量级的,与传统的大核2D深度可分离卷积相比,使用几个1D深度可分离核可以达到类似的效果,同时参数减少了 k b / 2 kb/2 kb/2。其次,条形卷积有助于识别和提取细长形状物体(如桥梁)的特征。
  • 随着CAA模块所属的PKI块深度增加,增大条形卷积的核大小( k b = 11 + 2 × l kb = 11 + 2×l kb=11+2×l),以增强PKINet建立长距离像素间关系的能力,同时由于条形深度可分离设计,不会显著增加计算成本。
  • 最后,CAA模块产生一个注意力权重,用于增强PKI模块的输出特征。具体来说,通过Sigmoid函数确保注意力图在范围 ( 0 , 1 ) (0, 1) (0,1)内,然后通过元素点乘和元素求和操作来增强特征。

在这里插入图片描述

2.2 优势

  • 有效捕捉长距离依赖:通过合适的核大小设置,能够更好地捕捉长距离像素间的依赖关系,相比于较小核大小的情况,能提升模型性能,因为较小核无法有效捕获长距离依赖,而较大核可以包含更多上下文信息。
  • 轻量化:条形卷积的设计使得CAA模块具有轻量化的特点,减少了参数数量和计算量。
  • 增强特征提取:当在PKINet的任何阶段使用CAA模块时,都能带来性能提升,当在所有阶段部署CAA模块时,性能增益达到 1.03 % 1.03\% 1.03%,这表明CAA模块能够有效地增强模型对特征的提取能力。

论文:https://arxiv.org/pdf/2403.06258
源码:https://github.com/NUST-Machine-Intelligence-Laboratory/PKINet

三、CAA的实现代码

CAA模块的实现代码如下:

from mmcv.cnn import ConvModule
from mmengine.model import BaseModuleclass CAA(BaseModule):"""Context Anchor Attention"""def __init__(self,channels: int,h_kernel_size: int = 11,v_kernel_size: int = 11,norm_cfg: Optional[dict] = dict(type='BN', momentum=0.03, eps=0.001),act_cfg: Optional[dict] = dict(type='SiLU'),init_cfg: Optional[dict] = None,):super().__init__(init_cfg)self.avg_pool = nn.AvgPool2d(7, 1, 3)self.conv1 = ConvModule(channels, channels, 1, 1, 0,norm_cfg=norm_cfg, act_cfg=act_cfg)self.h_conv = ConvModule(channels, channels, (1, h_kernel_size), 1,(0, h_kernel_size // 2), groups=channels,norm_cfg=None, act_cfg=None)self.v_conv = ConvModule(channels, channels, (v_kernel_size, 1), 1,(v_kernel_size // 2, 0), groups=channels,norm_cfg=None, act_cfg=None)self.conv2 = ConvModule(channels, channels, 1, 1, 0,norm_cfg=norm_cfg, act_cfg=act_cfg)self.act = nn.Sigmoid()def forward(self, x):attn_factor = self.act(self.conv2(self.v_conv(self.h_conv(self.conv1(self.avg_pool(x))))))return attn_factor

四、创新模块

4.1 改进点⭐

模块改进方法
1️⃣ 加入CAA模块CAA模块添加后如下:

在这里插入图片描述

注意❗:在5.2和5.3小节中需要声明的模块名称为:CAA

2️⃣:加入基于CAA模块C2f。利用CAA改进C2f模块,使模型能够更好地捕捉长距离像素间的依赖关系。

改进代码如下:

class C2f_CAA(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))self.att = CAA(c2)def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.att(self.cv2(torch.cat(y, 1)))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.att(self.cv2(torch.cat(y, 1)))

在这里插入图片描述

注意❗:在5.2和5.3小节中需要声明的模块名称为:C2f_CAA


五、添加步骤

5.1 修改ultralytics/nn/modules/block.py

此处需要修改的文件是ultralytics/nn/modules/block.py

block.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

CAAC2f_CAA模块代码添加到此文件下。

5.2 修改ultralytics/nn/modules/init.py

此处需要修改的文件是ultralytics/nn/modules/__init__.py

__init__.py文件中定义了所有模块的初始化,我们只需要将block.py中的新的模块命添加到对应的函数即可。

CAAC2f_CAAblock.py中实现,所有要添加在from .block import

from .block import (C1,C2,...CAA,C2f_CAA
)

在这里插入图片描述

5.3 修改ultralytics/nn/modules/tasks.py

tasks.py文件中,需要在两处位置添加各模块类名称。

首先:在函数声明中引入CAAC2f_CAA

在这里插入图片描述

在这里插入图片描述

其次:在parse_model函数中注册CAAC2f_CAA模块

在这里插入图片描述

在这里插入图片描述


六、yaml模型文件

6.1 模型改进版本⭐

此处以ultralytics/cfg/models/v10/yolov10m.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov10m-C2f_CAA.yaml

yolov10m.yaml中的内容复制到yolov10m-C2f_CAA.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 模型的修改方法是将骨干网络中的所有C2f模块替换成C2f_CAA模块

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f_CAA, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f_CAA, [256, True]]- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f_CAA, [512, True]]- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2fCIB, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 1, PSA, [1024]] # 10# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)- [-1, 1, SCDown, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

六、成功运行结果

分别打印网络模型可以看到C2f_CAA已经加入到模型中,并可以进行训练了。

yolov10m-C2f_CAA

                   from  n    params  module                                       arguments                     0                  -1  1      1392  ultralytics.nn.modules.conv.Conv             [3, 48, 3, 2]                 1                  -1  1     41664  ultralytics.nn.modules.conv.Conv             [48, 96, 3, 2]                2                  -1  2    172416  ultralytics.nn.modules.block.C2f_CAA         [96, 96, True]                3                  -1  1    166272  ultralytics.nn.modules.conv.Conv             [96, 192, 3, 2]               4                  -1  4   1353216  ultralytics.nn.modules.block.C2f_CAA         [192, 192, True]              5                  -1  1     78720  ultralytics.nn.modules.block.SCDown          [192, 384, 3, 2]              6                  -1  4   5360640  ultralytics.nn.modules.block.C2f_CAA         [384, 384, True]              7                  -1  1    228672  ultralytics.nn.modules.block.SCDown          [384, 576, 3, 2]              8                  -1  2   1689984  ultralytics.nn.modules.block.C2fCIB          [576, 576, 2, True]           9                  -1  1    831168  ultralytics.nn.modules.block.SPPF            [576, 576, 5]                 10                  -1  1   1253088  ultralytics.nn.modules.block.PSA             [576, 576]                    11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           13                  -1  2   1993728  ultralytics.nn.modules.block.C2f             [960, 384, 2]                 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           16                  -1  2    517632  ultralytics.nn.modules.block.C2f             [576, 192, 2]                 17                  -1  1    332160  ultralytics.nn.modules.conv.Conv             [192, 192, 3, 2]              18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           19                  -1  2    831744  ultralytics.nn.modules.block.C2fCIB          [576, 384, 2, True]           20                  -1  1    152448  ultralytics.nn.modules.block.SCDown          [384, 384, 3, 2]              21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           22                  -1  2   1911168  ultralytics.nn.modules.block.C2fCIB          [960, 576, 2, True]           23        [16, 19, 22]  1   2282134  ultralytics.nn.modules.head.v10Detect        [1, [192, 384, 576]]          
YOLOv10m-C2f_CAA summary: 707 layers, 19198246 parameters, 19198230 gradients, 80.9 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1557892.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式 c 内存堆栈增长方向往低地址方向好处

如下是堆和栈内存空间使用方式有如下好处: 1、stack从高地址向低地址扩展,这样栈空间的起始位置就能确定下来;如果反向,则要考虑这个起点从哪里合适,要确定堆的大小。 2、可以共用中间部分区域空间,最大化…

高速电路中电源设计问题

DCDC芯片都有一个开关频率,选型的时候注意一点这个问题。 纹波:纹波是电源波动中的低频部分,一般处于5Mhz以内的频段,铲子来自MOS的开关动作。 噪声:噪声值电源波动的高频部分,一般高于5Mhz,成分比较复杂…

UE5运行时动态加载场景角色动画任意搭配-角色及动画(一)

通过《MMD模型及动作一键完美导入UE5》系列文章,我们可以把外部场景、角色、动画资产导入UE5,接下来我们将实现运行时动态加载这些资产,并任意组合搭配。 1、骨骼动画复用 1、大部分模型骨骼是不通用的,比如这些裙子也是有骨骼的,属于模型特有的,但是对于动画来说,很多…

OmniCorpus数据集:最大(百亿级别)多模态数据集

2024-06-12 ,由上海人工智能实验室、哈尔滨工业大学、南京大学、复旦大学等联合创建OmniCorpus,一个达到百亿级别的图文交错数据集。它不仅规模空前,更以其多元化的数据来源和高质量的数据内容,为多模态大语言模型的研究提供了坚实…

Axure大屏可视化模板在多领域实践应用案例分析

Axure大屏可视化模板,凭借其强大的功能性和灵活性,在众多领域中发挥着举足轻重的作用。本文将详细探讨Axure大屏可视化模板在农业、园区管理、智慧城市、企业数据可视化和医疗领域的应用案例,展示其如何助力各行业实现智能化管理和决策优化。…

Mythical Beings:Web3游戏如何平衡创造内容、关注度与实现盈利的不可能三角

Web3游戏自其诞生以来,以去中心化和独特的代币经济体系迅速引起关注。然而,如何在创造内容、吸引用户和实现盈利之间达到平衡,始终是Web3游戏面临的核心挑战。Mythical Beings作为一款Web3卡牌游戏,通过创新设计和独特机制&#x…

【LeetCode: 1436. 旅行终点站 | 哈希表】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

面试题:Redis(一)

1. redis是单线程还是多线程? 2. IO多路复用听说过么? 3. Redis为什么快? 1. Redis是单线程还是多线程? 版本不同,Redis基于的架构也不同,所以单单问是单还是多线程并不严谨 3.x 之前 redis都是单线程 4.x …

微知-如何临时设置Linux系统时间?(date -s “2024-10-08 22:55:00“, time, hwclock, timedatectl)

背景 在tar解压包的时候经常出现时间不对,可以临时用date命令修改一下,也可以其他,本文主要介绍临时修改的方法 date命令修改 sudo date -s "2024-10-08 22:55:00"其他查看和修改的命令 本文只记录查看方式,修改的暂…

【java数据结构】ArrayList实例

【java数据结构】ArrayList实例 一、杨辉三角二、打扑克 一、杨辉三角 已知条件:给定行数的大小 思路:首先定义一个二维列表(也可以称为集合),我们对每一列处理完,最后把每一列加起来,不就是完整…

某象异形滑块99%准确率方案

注意,本文只提供学习的思路,严禁违反法律以及破坏信息系统等行为,本文只提供思路 如有侵犯,请联系作者下架 该文章模型已经上线ocr识别网站,欢迎测试!!,地址:http://yxlocr.nat300.top/ocr/slider/6 所谓的顶象异形滑块,是指没有采用常规的缺口,使用各种形状的缺口…

国外电商系统开发-运维系统文件上传-高级上传

如果您要上传文件到10台服务器中,有3台服务器的路径不是一样的,那么在这种情况下您就可以使用本功能,单独执行不一样的路径 点击【高级】上传

仿真技术入门书籍:《模拟集成电路设计与仿真》(可下载)

无论是在通信、医疗、消费电子还是工业控制领域,模拟集成电路都是实现复杂电子系统功能的关键。在电子工程领域,模拟集成电路设计是一门深奥而复杂的学科。随着技术的发展,设计者们需要掌握的不仅是电路设计原理,还包括仿真技术的…

【C语言刷力扣】1436.旅行终点站

题目: 解题思路: 两层循环查找,第一次循环中初始化 destination 为 path中每次旅行的终点作为最终的终点。二次循环查找当前 destination ,若是作为某次旅行的起点,说明不是最后的终点。 char* destCity(char ***paths…

Tomcat服务部署、优化及多实例实验

目录 一、Tomcat的基本介绍 1. tomcat是什么? 2.tomcat构成组件 2.1 web容器 2.2 servlet容器 2.3 jsp容器 3. tomcat的顶层架构 4.tomcat的核心功能 5.tomcat的请求过程 6.tomcat的配置文件 二、tomcat服务部署 1. 安装jdk、设置环境变量并测试 2.安装启动t…

Windows无需管理员权限,命令轻松修改IP和DNS

哈喽大家好,欢迎来到虚拟化时代君(XNHCYL)。 “ 大家好,我是虚拟化时代君,一位潜心于互联网的技术宅男。这里每天为你分享各种你感兴趣的技术、教程、软件、资源、福利…(每天更新不间断,福利…

【数据分享】1901-2023年我国省市县三级逐月最高气温数据(免费获取/Shp/Excel格式)

之前我们分享过1901-2023年1km分辨率逐月最高气温栅格数据(可查看之前的文章获悉详情),该数据来源于国家青藏高原科学数据中心,很多小伙伴拿到数据后反馈栅格数据不太方便使用,问我们能不能把数据处理为更方便使用的Sh…

计算机网络:数据链路层详解

目录 一、点对点信道: (1)封装成帧 (2)透明传输 (3)差错检测 二、点对点协议 (1)数据链路层的特点 (2)PPP协议的组成 (3&…

Vue3 使用 pinia

什么是Pinia Pinia是 Vue 的存储库,它允许您跨组件/页面共享状态,与vuex功能一样。 准备 安装 npm install pinia 或者 yarn add pinia使用 首先修改main.ts文件 main.ts import ./assets/main.cssimport { createApp } from vue import App from…

《强烈推荐一个强大的书签管理工具》

在信息爆炸的时代,我们每天都会浏览大量的网页,收藏各种各样的书签。然而,随着书签数量的增加,管理起来也变得越来越困难。这时,一个强大的书签管理工具就显得尤为重要。今天,我要向大家推荐一款备受好评的…