YOLOv11改进 | 注意力篇 | YOLOv11引入ACmix注意力机制

1. ACmix介绍

1.1  摘要:卷积和自注意力是表示学习的两种强大技术,它们通常被认为是两种彼此不同的同行方法。 在本文中,我们表明它们之间存在很强的潜在关系,从某种意义上说,这两种范式的大量计算实际上是通过相同的操作完成的。 具体来说,我们首先证明内核大小为 k×k 的传统卷积可以分解为 k2 个单独的 1×1 卷积,然后进行移位和求和操作。 然后,我们将自注意力模块中的查询、键和值的投影解释为多个 1×1 卷积,然后计算注意力权重和值的聚合。 因此,两个模块的第一阶段都包括类似的操作。 更重要的是,与第二级相比,第一级贡献了主要的计算复杂性(通道大小的平方)。 这种观察自然地导致了这两个看似不同的范式的优雅整合,即一个混合模型,它既享受自注意力和卷积(ACmix)的好处,同时与纯卷积或自注意力对应物相比具有最小的计算开销。 大量实验表明,我们的模型在图像识别和下游任务方面比竞争基线取得了持续改进的结果。

官方论文地址:https://arxiv.org/pdf/2111.14556v1.pdf

官方代码地址:https://github.com/Panxuran/ACmix

1.2  简单介绍:  

          ACmix模块是一种结合了自注意力和卷积技术的混合模型,旨在通过最小化计算开销来整合这两种看似不同的范式。该模块的核心在于揭示了自注意力和卷积之间存在的强大内在联系,这种联系主要体现在两者在第一阶段的计算复杂性上。具体而言,ACmix首先使用1×1卷积将输入特征图投影,以获得一组丰富的中间特征。然后,这些中间特征被重用,并分别遵循自注意力和卷积的方式聚合。

          在ACmix中,两个主要阶段(I和II)共同作用:在第一阶段,输入特征通过三个1×1卷积进行投影,生成包含3×N特征图的丰富中间特征集;第二阶段则根据不同范式(即自注意力和卷积方式)使用这些中间特征。此外,为了提高模型的灵活性和效率,ACmix还引入了几个改进措施,包括使用多个组卷积分解复杂的位移操作以及采用可学习的卷积分类器初始化固定核。

1.3  ACmix模块结构图

2. 核心代码

import torch
import torch.nn as nndef position(H, W, type, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1).to(type)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W).to(type)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stridepe = self.conv_p(position(h, w, x.dtype, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out)  # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out)  # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)f_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv

3. YOLOv11中添加ACmix

3.1 在ultralytics/nn下新建Extramodule

 3.2 在Extramodule里创建ACmix

在ACmix.py文件里添加给出的ACmix代码

添加完ACmix代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在tasks.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {ACmix}:args = [ch[f],  ch[f]]

4. 新建一个yolo11ACmix.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, ACmix, []]- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1, 1, ACmix, []]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, ACmix, []]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [-1, 1, ACmix, []]- [[17, 21, 26], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5.模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLOif __name__ == '__main__':model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11ACmix.yaml')model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',cache=False,imgsz=640,epochs=100,single_cls=False,  # 是否是单类别检测batch=8,close_mosaic=10,workers=0,device='0',optimizer='SGD',amp=True,project='runs/train',name='exp',)

模型结构打印,成功运行 :

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1552892.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 进程状态、僵尸进程与孤儿进程

目录 0.前言 1. 进程状态 1.1 定义 1.2 常见进程 2.僵尸进程 2.1 定义 2.2 示例 2.3 僵尸进程的危害与防止方法 3. 孤儿进程 3.1 介绍 3.2 示例 4.小结 (图像由AI生成) 0.前言 在上一篇文章中,我们介绍了进程的基本概念、进程控制块&#…

蓝桥杯—STM32G431RBT6(IIC通信--EEPROM(AT24C02)存储器进行通信)

一、什么是IIC?24C02存储器有什么用? IIC (IIC 是半双工通信总线。半双工意味着数据在某一时刻只能沿一个方向传输,即发送数据的时候不能接收数据,接收数据的时候不能发送数据)即集成电路总线(…

Activiti7 工作流引擎学习

目录 一. 什么是 Activiti 工作流引擎 二. Activiti 流程创建步骤 三. Activiti 数据库表含义 四. BPMN 建模语言 五. Activiti 使用步骤 六. 流程定义与流程实例 一. 什么是 Activiti 工作流引擎 Activiti 是一个开源的工作流引擎,用于业务流程管理&#xf…

第二弹:面向对象编程中的类与对象

文章目录 面向对象编程中的类与对象1. 类与对象的定义1.1 类和对象的概念1.2 类的基本定义 2. 类的封装2.1 类的封装语法2.2 类成员访问权限2.3 struct和class的区别2.4 类封装与成员函数定义分离 3. 类对象的创建与销毁3.1 静态与动态对象的创建3.2 对象的销毁 4. 构造函数和析…

深入解析 ConcurrentHashMap:从 JDK 1.7 到 JDK 1.8

✨探索Java基础 ConcurrentHashMap✨ 引言 ConcurrentHashMap 是 Java 中一个线程安全的高效 Map 集合。它在多线程环境下提供了高性能的数据访问和修改能力。本文将详细探讨 ConcurrentHashMap 在 JDK 1.7 和 JDK 1.8 中的不同实现方式,以及它们各自的优缺点。 …

(笔记)第三期书生·浦语大模型实战营(十一卷王场)--书生入门岛通关第2关Python 基础知识

学员闯关手册:https://aicarrier.feishu.cn/wiki/ZcgkwqteZi9s4ZkYr0Gcayg1n1g?open_in_browsertrue 课程视频:https://www.bilibili.com/video/BV1mS421X7h4/ 课程文档:https://github.com/InternLM/Tutorial/tree/camp3/docs/L0/Python 关…

如何使用ssm实现基于JSP的高校听课评价系统

TOC ssm753基于JSP的高校听课评价系统jsp 绪论 1.1 研究背景 现在大家正处于互联网加的时代,这个时代它就是一个信息内容无比丰富,信息处理与管理变得越加高效的网络化的时代,这个时代让大家的生活不仅变得更加地便利化,也让时…

【LeetCode: 1870. 准时到达的列车最小时速 | 二分】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

各种饺子的做法

【羊肉馅水饺】 材料:羊肉1000克、洋葱2个、香油3汤匙、盐适量、姜2片、料酒1汤匙、白胡椒粉、十三香1茶匙、 做法: 1.把羊肉剁成肉馅,羊肉选用带一些肥肉的,味道比较香,如果羊肉比较瘦,可以放一些猪的肥肉一起剁成馅…

电商店铺多开自动回复软件

在电商平台上开设多个店铺,即店铺多开,是一种扩展业务和增加销售额的策略。然而,店铺多开需要谨慎规划和执行,以避免违反平台规定和管理上的混乱。以下是如何实现店铺多开的详细步骤和注意事项。 1. 确定多开目标 在决定多开店铺…

4个顶级的大模型推理引擎

LLM 在文本生成应用中表现出色,例如具有高理解度和流畅度的聊天和代码完成模型。然而,它们的庞大规模也给推理带来了挑战。基本推理速度很慢,因为 LLM 会逐个生成文本标记,需要对每个下一个标记进行重复调用。随着输入序列的增长&…

【CKA】七、七层负载-Ingress应用

7、七层负载-Ingress应用 1. 考题内容: 2. 答题思路: 1、要先查到集群中使用的ingressclass 2、编写yaml 我考的题只是把 hi 服务换成了 hello,其他都一模一样 3. 官网地址: https://kubernetes.io/zh-cn/docs/concepts/serv…

Pytorch实现RNN实验

一、实验要求 用 Pytorch 模块的 RNN 实现生成唐诗。要求给定一个字能够生成一首唐诗。 二、实验目的 理解循环神经网络(RNN)的基本原理:通过构建一个基于RNN的诗歌生成模型,学会RNN是如何处理序列数据的,以及如何在…

LabVIEW提高开发效率技巧----快速实现原型和测试

在LabVIEW开发中,DAQ助手(DAQ Assistant)和Express VI为快速构建原型和测试功能提供了极大的便利,特别适合于简单系统的开发和早期验证阶段。 DAQ助手:是一种可视化配置工具,通过图形界面轻松设置和管理数据…

HISTCITE分析进阶

不可否认histcite是一个很好的文献分析的工具,他能很好的找到最重要的那几篇文章,同时也能找到研究的发文趋势、研究机构和著名的研究学者等。但是它是一个很老的软件,因而很多东西都没能跟上下载的分析。我在使用过程中,尝试做一些改变使其更好用,同时也做一些记录。 1.…

C语言数组和指针笔试题(四)

目录 二维数组例题一例题二例题三例题四例题五例题六例题七例题八例题九例题十例题十一 结果 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 🐒🐒🐒个人主页 🥸🥸🥸C语言 🐿️…

vulnhub-Matrix 1靶机

vulnhub:https://www.vulnhub.com/entry/matrix-1,259/ 导入靶机,扫描IP 靶机在192.168.81.6,扫描端口 存在三个端口,有两个都是http服务,访问 80端口的网页没什么信息,31337的网页元素里有注释 ZWNobyAi…

加密与安全_HTOP 一次性密码生成算法

文章目录 HOTP 的基础原理HOTP 的工作流程HOTP 的应用场景HOTP 的安全性安全性增强措施Code生成HOTP可配置项校验HOTP可拓展功能计数器(counter)计数器在客户端和服务端的作用计数器的同步机制客户端和服务端中的计数器表现服务端如何处理计数器不同步计…

dubbo微服务

一.启动nacos和redis 1.虚拟机查看是否开启nacos和redis docker ps2.查看是否安装nacos和redis docker ps -a3.启动nacos和redis docker start nacos docker start redis-6379 docker ps二.创建三个idea的maven项目 1.第一个项目dubboapidemo 2.1.1向pom.xml里添加依赖 …

MES(软件)系统是什么?MES系统为何如此重要呢?

一、MES系统的定义与功能 MES系统是一套面向制造企业车间执行层的生产信息化管理系统,它涵盖了多种功能模块,包括但不限于: 订单管理:处理客户订单,确保生产需求与市场需求相匹配。生产调度:根据订单和生…