YOLOv10改进策略【注意力机制篇】| 引入MobileNetv4中的Mobile MQA,提高模型效率

一、本文介绍

本文记录的是基于Mobile MQA模块的YOLOv10目标检测改进方法研究MobileNetv4中的Mobile MQA模块是用于模型加速,减少内存访问的模块,相比其他全局的自注意力,其不仅加强了模型对全局信息的关注,同时也显著提高了模型效率。

文章目录

  • 一、本文介绍
  • 二、Mobile MQA注意力原理
  • 三、Mobile MQA的实现代码
  • 四、添加步骤
    • 4.1 改进点⭐
  • 五、添加步骤
    • 5.1 修改ultralytics/nn/modules/block.py
    • 5.2 修改ultralytics/nn/modules/__init__.py
    • 5.3 修改ultralytics/nn/modules/tasks.py
  • 六、yaml模型文件
    • 6.1 模型改进⭐
  • 七、成功运行结果


二、Mobile MQA注意力原理

在论文《MobileNetV4 - Universal Models for the Mobile Ecosystem》中,提出了Mobile MQA

一、原理

  1. 基于MQA改进并结合不对称空间下采样
    • MQA(Multi-Query Attention)简化了传统的多头注意力机制,通过共享keysvalues来减少内存访问需求。在移动混合模型中,当批量大小较小时,这种方式能有效提高运算强度。
    • 借鉴MQA中对querieskeysvalues的不对称计算方式,Mobile MQA引入了空间缩减注意力(SRA),对keysvalues进行下采样,同时保持高分辨率的queries。这是因为在混合模型中,早期层的空间混合卷积滤波器使得空间上相邻的标记具有相关性。
    • Mobile MQA的计算公式为:
      M o b i l e _ M Q A ( X ) = C o n c a t ( a t t e n t i o n 1 , . . . , a t t e n t i o n n ) W O Mobile\_MQA(X)= Concat(attention_1,...,attention_n)W^{O} Mobile_MQA(X)=Concat(attention1,...,attentionn)WO
      其中 a t t e n t i o n j = s o f t m a x ( ( X W Q j ) ( S R ( X ) W K ) T d k ) ( S R ( X ) W V ) attention_j = softmax(\frac{(XW^{Q_j})(SR(X)W^{K})^{T}}{\sqrt{d_k}})(SR(X)W^{V}) attentionj=softmax(dk (XWQj)(SR(X)WK)T)(SR(X)WV),这里SR可以是空间缩减操作(在设计中是一个步长为2的3x3深度卷积),也可以是恒等函数(当不进行空间缩减时)。

二、特点

  1. 针对加速器优化:专门为移动加速器进行了优化,考虑了移动加速器的计算和内存特性。
  2. 不对称空间下采样:通过对keysvalues进行下采样,保持queries的高分辨率,在不损失太多精度的情况下,显著提高了效率。
  3. 操作简单高效:相比传统的注意力机制,Mobile MQA的设计更加简单,操作更加高效,更适合在移动设备上运行。

论文:http://arxiv.org/abs/2404.10518
源码:https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/mobilenet.py

三、Mobile MQA的实现代码

Mobile MQA模块的实现代码如下:


def conv2d(in_channels, out_channels, kernel_size=3, stride=1, groups=1, bias=False, norm=True, act=True):conv = nn.Sequential()padding = (kernel_size - 1) // 2conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, groups=groups))if norm:conv.append(nn.BatchNorm2d(out_channels))if act:conv.append(nn.ReLU6())return convclass MultiQueryAttentionLayerWithDownSampling(nn.Module):def __init__(self, in_channels, num_heads, key_dim, value_dim, query_h_strides, query_w_strides, kv_strides, dw_kernel_size=3, dropout=0.0):"""Multi Query Attention with spatial downsampling.Referenced from here https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py3 parameters are introduced for the spatial downsampling:1. kv_strides: downsampling factor on Key and Values only.2. query_h_strides: vertical strides on Query only.3. query_w_strides: horizontal strides on Query only.This is an optimized version.1. Projections in Attention is explict written out as 1x1 Conv2D.2. Additional reshapes are introduced to bring a up to 3x speed up."""super(MultiQueryAttentionLayerWithDownSampling, self).__init__()self.num_heads = num_headsself.key_dim = key_dimself.value_dim = value_dimself.query_h_strides = query_h_stridesself.query_w_strides = query_w_stridesself.kv_strides = kv_stridesself.dw_kernel_size = dw_kernel_sizeself.dropout = dropoutself.head_dim = self.key_dim // num_headsif self.query_h_strides > 1 or self.query_w_strides > 1:self._query_downsampling_norm = nn.BatchNorm2d(in_channels)self._query_proj = conv2d(in_channels, self.num_heads * self.key_dim, 1, 1, norm=False, act=False)if self.kv_strides > 1:self._key_dw_conv = conv2d(in_channels, in_channels, dw_kernel_size, kv_strides, groups=in_channels,norm=True, act=False)self._value_dw_conv = conv2d(in_channels, in_channels, dw_kernel_size, kv_strides, groups=in_channels,norm=True, act=False)self._key_proj = conv2d(in_channels, key_dim, 1, 1, norm=False, act=False)self._value_proj = conv2d(in_channels, key_dim, 1, 1, norm=False, act=False)self._output_proj = conv2d(num_heads * key_dim, in_channels, 1, 1, norm=False, act=False)self.dropout = nn.Dropout(p=dropout)def forward(self, x):bs, seq_len, _, _ = x.size()# print(x.size())if self.query_h_strides > 1 or self.query_w_strides > 1:q = F.avg_pool2d(self.query_h_strides, self.query_w_strides)q = self._query_downsampling_norm(q)q = self._query_proj(q)else:q = self._query_proj(x)px = q.size(2)q = q.view(bs, self.num_heads, -1, self.key_dim)  # [batch_size, num_heads, seq_len, key_dim]if self.kv_strides > 1:k = self._key_dw_conv(x)k = self._key_proj(k)v = self._value_dw_conv(x)v = self._value_proj(v)else:k = self._key_proj(x)v = self._value_proj(x)k = k.view(bs, 1, self.key_dim, -1)   # [batch_size, 1, key_dim, seq_length]v = v.view(bs, 1, -1, self.key_dim)    # [batch_size, 1, seq_length, key_dim]# calculate attention score# print(q.shape, k.shape, v.shape)attn_score = torch.matmul(q, k) / (self.head_dim ** 0.5)attn_score = self.dropout(attn_score)attn_score = F.softmax(attn_score, dim=-1)# context = torch.einsum('bnhm,bmv->bnhv', attn_score, v)# print(attn_score.shape, v.shape)context = torch.matmul(attn_score, v)context = context.view(bs, self.num_heads * self.key_dim, px, px)output = self._output_proj(context)# print(output.shape)return output
参数解释
in_channels输入通道数
num_heads自注意力头的数量
key_dim键的维度
key_dim值的维度
value_dim仅用于查询的,在H方向上的步长
query_h_strides仅用于查询的,在W方向上的步长
query_w_strides仅对键和值进行下采样,1不进行下采样,2下采样
dw_kernel_size=3深度可分离卷积的卷积核大小
dropout=0.0随机丢失比例

四、添加步骤

4.1 改进点⭐

模块改进方法:基于Mobile MQA模块C2f

改进方法是对YOLOv10中的C2f模块进行改进。MobileNetv4中的Mobile MQA模块可用于模型加速,减少内存访问的模块,相比其他全局的自注意力,利用Mobile MQA改进C2f模块后,不仅加强了模型对全局信息的关注,同时也显著提高了模型效率。

改进代码如下:

class C2f_MQA(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))self.att = MultiQueryAttentionLayerWithDownSampling(c2, 2, 48, 48, 1, 1, 1)def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.att(self.cv2(torch.cat(y, 1)))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.att(self.cv2(torch.cat(y, 1)))

在这里插入图片描述

在这里插入图片描述

注意❗:在5.2和5.3小节中的文件中需要声明的模块名称为:MultiQueryAttentionLayerWithDownSamplingC2f_MQA


五、添加步骤

5.1 修改ultralytics/nn/modules/block.py

此处需要修改的文件是ultralytics/nn/modules/block.py

block.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

MultiQueryAttentionLayerWithDownSamplingC2f_MQA模块代码添加到此文件下。

5.2 修改ultralytics/nn/modules/init.py

此处需要修改的文件是ultralytics/nn/modules/__init__.py

__init__.py文件中定义了所有模块的初始化,我们只需要将block.py中的新的模块命添加到对应的函数即可。

MultiQueryAttentionLayerWithDownSamplingC2f_MQAblock.py中实现,所有要添加在from .block import

from .block import (C1,C2,...MultiQueryAttentionLayerWithDownSampling,C2f_MQA
)

在这里插入图片描述

5.3 修改ultralytics/nn/modules/tasks.py

tasks.py文件中,需要在两处位置添加各模块类名称。

首先:在函数声明中引入MultiQueryAttentionLayerWithDownSamplingC2f_MQA

在这里插入图片描述

在这里插入图片描述

其次:在parse_model函数中注册MultiQueryAttentionLayerWithDownSamplingC2f_MQA模块

在这里插入图片描述

在这里插入图片描述


六、yaml模型文件

6.1 模型改进⭐

在代码配置完成后,配置模型的YAML文件。

此处以ultralytics/cfg/models/v10/yolov10m.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov10m-C2f_MQA.yaml

yolov10m.yaml中的内容复制到yolov10m-C2f_MQA.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 模型的修改方法是将骨干网络中的所有C2f模块替换成C2f_MQA模块,优化整体,提高效率。

结构如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f_MQA, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f_MQA, [256, True]]- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f_MQA, [512, True]]- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2fCIB, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 1, PSA, [1024]] # 10# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)- [-1, 1, SCDown, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

七、成功运行结果

分别打印网络模型可以看到C2f_MQA已经加入到模型中,并可以进行训练了。

YOLOv10m-C2f_MQA

                   from  n    params  module                                       arguments                     0                  -1  1      1392  ultralytics.nn.modules.conv.Conv             [3, 48, 3, 2]                 1                  -1  1     41664  ultralytics.nn.modules.conv.Conv             [48, 96, 3, 2]                2                  -1  2    185472  ultralytics.nn.modules.block.C2f_MQA         [96, 96, True]                3                  -1  1    166272  ultralytics.nn.modules.conv.Conv             [96, 192, 3, 2]               4                  -1  4   1257984  ultralytics.nn.modules.block.C2f_MQA         [192, 192, True]              5                  -1  1     78720  ultralytics.nn.modules.block.SCDown          [192, 384, 3, 2]              6                  -1  4   4580352  ultralytics.nn.modules.block.C2f_MQA         [384, 384, True]              7                  -1  1    228672  ultralytics.nn.modules.block.SCDown          [384, 576, 3, 2]              8                  -1  2   1689984  ultralytics.nn.modules.block.C2fCIB          [576, 576, 2, True]           9                  -1  1    831168  ultralytics.nn.modules.block.SPPF            [576, 576, 5]                 10                  -1  1   1253088  ultralytics.nn.modules.block.PSA             [576, 576]                    11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           13                  -1  2   1993728  ultralytics.nn.modules.block.C2f             [960, 384, 2]                 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           16                  -1  2    517632  ultralytics.nn.modules.block.C2f             [576, 192, 2]                 17                  -1  1    332160  ultralytics.nn.modules.conv.Conv             [192, 192, 3, 2]              18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           19                  -1  2    831744  ultralytics.nn.modules.block.C2fCIB          [576, 384, 2, True]           20                  -1  1    152448  ultralytics.nn.modules.block.SCDown          [384, 384, 3, 2]              21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           22                  -1  2   1911168  ultralytics.nn.modules.block.C2fCIB          [960, 576, 2, True]           23        [16, 19, 22]  1   2282134  ultralytics.nn.modules.head.v10Detect        [1, [192, 384, 576]]          
YOLOv10m-C2f_MQA summary: 657 layers, 18335782 parameters, 18335766 gradients, 77.8 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1560523.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot洗衣店订单系统:简化您的业务流程

1系统概述 1.1 研究背景 如今互联网高速发展,网络遍布全球,通过互联网发布的消息能快而方便的传播到世界每个角落,并且互联网上能传播的信息也很广,比如文字、图片、声音、视频等。从而,这种种好处使得互联网成了信息传…

JavaScript 常量/数据类型/类型转换 简单学习

目录 1. 常量 1.1 常量概述 1.2 案例 1.3 总结 2. 数据类型 2.1 概述 2.2 分类 2.2.1 基本数据类型 2.2.1 基本数据类型——number (数值/字型) 2.2.1 数字型——算术运算符 2.2.1 基本数据类型——String (字符串类型) 2.2.1 字符串拼接 2.2.1 模板字符串 2.2.1…

VMwareWorkstation安装KylinV10sp3(银河麒麟)系统教程

版本说明 VMware版本如下 OS版本如下 创建虚拟机 点击创建新的虚拟机 按图下所示选择,点击下一步 按照图下所示选择,点击下一步 按照图下所示选择,点击下一步 按照图下所示选择,点击下一步 设置虚拟机名称,点击下一步…

java-02 数据结构-队列

在Java中,队列是一种常见的数据结构,用于在保持顺序的同时存储和检索数据。Java提供了java.util.Queue接口,它的常见实现包括ArrayDeque、LinkedList和PriorityQueue等。 如果你觉得我分享的内容或者我的努力对你有帮助,或者你只…

git删除错误的commit

git的流程如图: 当某次失误造成commit的版本有问题,需要回退到正常的版本修改后重新add。 首先通过git log查看commit提交记录,可以看到HEAD->mater是本地最新的commit,而origin/master, origin/HEAD是远程仓库上的最新记录…

Golang | Leetcode Golang题解之第467题环绕字符串中唯一的子字符串

题目: 题解: func findSubstringInWraproundString(p string) (ans int) {dp : [26]int{}k : 0for i, ch : range p {if i > 0 && (byte(ch)-p[i-1]26)%26 1 { // 字符之差为 1 或 -25k} else {k 1}dp[ch-a] max(dp[ch-a], k)}for _, v :…

【GT240X】【3】Wmware17和Centos 8 安装

文章目录 一、说明二、安装WMware2.1 下载WMware2.2 安装2.3 虚拟机的逻辑结构 三、安装Centos3.1 获取最新版本Centos3.2 创建虚拟机 四、问题和简答4.1 centos被淘汰了吗?4.2 centos里面中文显示成小方块的解决方法4.3 汉语-英语输入切换4.4 全屏和半屏切换 五、练…

【mmsegmentation】Backbone模块(进阶)自定义自己的BACKBONE

1、定义自己的backboe driving\models\backbones\efficientnetlite.py import math import torch import torch.nn as nn import torch.functional as F from mmengine.model import BaseModule from mmseg.models import BACKBONES, build_backboneefficientnet_lite_params …

双主轴车床的优势

双主轴车床作为现代制造业中的一项重要技术,其优势显而易见。下面我将从几个方面详细阐述双主轴车床的优势: ‌一、提高生产效率‌ ‌并行加工‌:双主轴车床最大的特点在于其能够同时在两个主轴上进行加工,这种并行加工方式使得在…

苍穹外卖--分页查询

pagehelper插件 先导入坐标 重点代码:service层 利用pagehelper动态拼接limit语句 不用在mapper中写limit 底层利用localthread来传递数据 时间显示不规范 解决方式: 方式一:在属性上加入注解,对日期进行格式化 方式二&#x…

vue基础(总结很详细)

vue 基础 1. vue 是什么? Vue.js (读音 /vju ː /, 类似于 view ) 是一套构建用户界面的渐进式框架。 Vue 只关注视图层, 采用自底向上增量开发的设计。 Vue 的目标是通过尽可能简单的 API 实现响应的数据绑定和组合的视图…

set的基本用法 和 底层简单了解

在前面,已经了解过搜索二叉树了,也了解了一点红黑树的内容(不太了解的可以先查看前面的内容哦);现在我们了学习一下,底层以红黑树实现,遍历以搜索树的中序实现的set/multset; 序列式…

Java | Leetcode Java题解之第470题用Rand7()实现Rand10()

题目&#xff1a; 题解&#xff1a; class Solution extends SolBase {public int rand10() {int a, b, idx;while (true) {a rand7();b rand7();idx b (a - 1) * 7;if (idx < 40) {return 1 (idx - 1) % 10;}a idx - 40;b rand7();// get uniform dist from 1 - 63…

如何实现MySQL异地多活场景

作为现代化的互联网企业 &#xff0c;最怕的是什么 &#xff1f;是意外&#xff01;由各种意外导致的数据库问题&#xff0c;磁盘问题、网络问题、人员误操作问题等等&#xff0c;这些问题都可能导致数据不可用或者丢失&#xff0c;造成重大损失。因此&#xff0c;很少会有企业…

【AI系统】AI 学习方法与算法现状

在人工智能&#xff08;AI&#xff09;的漫长历史中&#xff0c;我们见证了从早期的规则驱动系统到现代的机器学习模型的转变。AI的学习方法是其进步的核心&#xff0c;而算法现状则反映了当前技术的高度和未来的发展方向。 Ⅰ.AI 学习方法 AI的工作原理基于深度神经网络&…

24.4 基于consul服务发现模式

本节重点介绍 : consul 安装consul go代码注册服务&#xff0c;注销服务&#xff0c;获取服务node_exporter改造为consul服务发现在数量比较大时&#xff0c;在注册服务的时候&#xff0c;关闭check&#xff0c;可以降低consul的压力 consul 安装 准备工作 # 下载consul wge…

C++ | Leetcode C++题解之第470题用Rand7()实现Rand10()

题目&#xff1a; 题解&#xff1a; class Solution { public:int rand10() {int a, b, idx;while (true) {a rand7();b rand7();idx b (a - 1) * 7;if (idx < 40) {return 1 (idx - 1) % 10;}a idx - 40;b rand7();// get uniform dist from 1 - 63idx b (a - 1)…

hadoop入门

1.1 hadoop是什么 Hadoop是一个由Apache基金会所开发的分布式系统基础架构&#xff0c;主要是解决海量数据的存储和海量数据的分析计算的问题。通常Hadoop指的是一个更为广泛的概念Hadoop生态圈 1.2 hadoop发展历程 Hadoop创始人Doug Cutting&#xff0c;为了实现与Google类…

小猿口算APP脚本(协议版)

小猿口算是一款专注于数学学习的教育应用,主要面向小学阶段的学生。它提供多种数学练习和测试,包括口算、速算、应用题等。通过智能化的题目生成和实时批改功能,帮助学生提高数学计算能力。此外,它还提供详细的学习报告和分析,帮助家长和教师了解学生的学习进度和薄弱环节…

数据结构前置知识(下)

1. 包装类 Java为了让基本数据类型也能够继承Object,因此给每个基本数据类型提供了包装类, 这样就可以和平常的引用数据类型一样使用了,并且也可以应用在泛型上(后续讲) 基本数据类型包装类byteByteshortShortintIntergerlongLongfloatFloatdoubleDoublecharCharacterboolean…