笔记02----重新思考轻量化视觉Transformer中的局部感知CloFormer(即插即用)

1. 基本信息

  • 论文标题: 《Rethinking Local Perception in Lightweight Vision Transformer》
  • 中文标题: 《重新思考轻量化视觉Transformer中的局部感知》
  • 作者单位: 清华大学
  • 发表时间: 2023
  • 论文地址: https://arxiv.org/abs/2303.17803
  • 代码地址: https://github.com/qhfan/CloFormer

2. 应用场景

  • 图像分类、目标检测、语义分割等领域。

3. 研究背景

  • 现阶段,Transformer在图像分类、目标检测、语义分割等领域表现出优异的性能。然而Transformer参数量和计算量太大,不适合部署到移动设备。
  • 在现有的轻量级Transformer模型中,大多数方法只注重设计稀疏注意力以有效处理低频全局信息,而处理高频局部信息的方法相对简单。

4. 方法概述

为了同时利用共享权重和上下文感知权重的优势,提出了CloFormer,这是一种具有上下文感知局部增强功能的轻量级视觉转换器,具体贡献如下:

  1. CloFormer 中,引入了一种名为 AttnConv 的卷积算子,它采用注意力机制,充分利用共享权重和上下文感知权重的优势来实现局部感知。 此外,它使用了一种新方法,该方法结合了比普通局部自注意力更强的非线性来生成上下文感知权重。
  2. CloFormer 中,采用双分支架构,其中一个分支使用AttnConv 捕获高频信息,而另一个分支使用带有下采样的普通注意力捕获低频信息。 双分支结构使 CloFormer 能够同时捕获高频和低频信息。
  3. 该方法在图像分类、目标检测和语义分割方面的广泛实验证明了 CloFormer 的优越性。 CloFormerImageNet1k 上仅用 4.2M 参数和 0.6G FLOP 就实现了 77.0% 的准确率,明显优于其他模型。

4.1 整体网络结构

在这里插入图片描述

如上图所示,CloFormer包含一个卷积主干和四个阶段。每个阶段由Clo block和ConvFFN组成, 先通过卷积主干传递输入图像以获得tokens。 该系统由四个卷积组成,每个卷积的步幅分别为2、2、1和1。 随后,标记经过四个阶段的Clo block和ConvFFN来提取层次特征。 最后,利用全局平均池化和全连接层来生成预测。

  • ConvFFN

为了将局部信息整合到FFN过程中,用ConvFFN取代了传统的FFN。 ConvFFN和常用的FFN之间的主要区别是,ConvFFN在GELU激活后使用深度卷积(DWconv),这使得ConvFFN能够聚合局部信息。 由于DWconv,下行采样可以直接在ConvFFN中执行,而无需引入PatchMerge模块。 CloFormer使用了两种类型的ConvFFN。 第一种是级内ConvFFN,它直接利用跳过连接。 另一个是连接两个阶段的ConvFFN。 在这种类型的ConvFFN的跳过连接中,使用DWconv和全连接层分别对输入进行下采样和上维。

在这里插入图片描述

  • Clo block

每个块由一个本地分支和一个全局分支组成。 在全局分支中,首先对K和V进行下采样,然后对Q、K和V进行标准attention处理,提取低频全局信息。

在这里插入图片描述

4.2 AttnConv模块

全局分支的模式有效地减少了需要注意的flop的数量,也产生了一个全局接受野。 然而,它在有效捕获低频全局信息的同时,对高频局部信息的处理能力不足。
在AttnConv中,首先应用线性变换得到Q,K, V,这与标准注意力相同,在进行线性变换后,首先对V进行共享权值的局部特征聚合处理,然后基于处理后的V和Q, K进行上下文感知的局部增强。具体分为为如下三个步骤:

  • Local Feature Aggregation

使用一个简单的深度卷积(DWconv)来对 V 进行局部信息聚合。

  • Context-aware Local Enhancement

使用两个DWconv分别聚合Q和K的本地信息。 然后,计算Q和K的Hadamard积,并对结果进行一系列变换,以获得−1到1之间的上下文感知权重。 最后,利用生成的权值对局部特征进行增强。

  • Fusion with Global Branch

使用简单的方法将局部分支的输出与全局分支的输出融合。

4.3 代码

可以将Clo block当作注意力机制使用,具体代码如下:

import torch  
import torch.nn as nn  
from efficientnet_pytorch.model import MemoryEfficientSwish  # 从 EfficientNet 的库中引入高效激活函数 Swish  
class AttnMap(nn.Module):  def __init__(self, dim):  super().__init__()  # 定义一个包含两层卷积和激活函数的块,用于生成注意力映射  self.act_block = nn.Sequential(  nn.Conv2d(dim, dim, 1, 1, 0),  # 1x1 卷积,保持通道数不变  MemoryEfficientSwish(),       # Swish 激活函数  nn.Conv2d(dim, dim, 1, 1, 0)  # 再次使用 1x1 卷积  )  def forward(self, x):  return self.act_block(x)  # 前向传播,返回处理后的张量  class CloAttention(nn.Module):  def __init__(self, dim, num_heads=8, group_split=[4, 4], kernel_sizes=[5], window_size=4,  attn_drop=0., proj_drop=0., qkv_bias=True):  super().__init__()  # 参数初始化和断言检查  assert sum(group_split) == num_heads  # 确保分组的头总数等于注意力头总数  assert len(kernel_sizes) + 1 == len(group_split)  # 核大小和分组数一致  self.dim = dim  # 输入通道数  self.num_heads = num_heads  # 总的多头注意力头数  self.dim_head = dim // num_heads  # 每个头的通道数  self.scalor = self.dim_head ** -0.5  # 注意力缩放因子  self.kernel_sizes = kernel_sizes  # 高频分支的卷积核大小  self.window_size = window_size  # 低频分支窗口大小  self.group_split = group_split  # 每个分支分配的头数  # 创建高频和低频分支的模块  convs = []  # 高频卷积  act_blocks = []  # 高频注意力模块  qkvs = []  # 高频分支的 QKV 卷积  for i in range(len(kernel_sizes)):  kernel_size = kernel_sizes[i]  group_head = group_split[i]  if group_head == 0:  continue  # 如果分组头数为 0,跳过此分支  convs.append(nn.Conv2d(3 * self.dim_head * group_head, 3 * self.dim_head * group_head, kernel_size,  1, kernel_size // 2, groups=3 * self.dim_head * group_head))  # 高频卷积  act_blocks.append(AttnMap(self.dim_head * group_head))  # 注意力映射模块  qkvs.append(nn.Conv2d(dim, 3 * group_head * self.dim_head, 1, 1, 0, bias=qkv_bias))  # QKV 卷积  # 定义低频全局注意力分支  if group_split[-1] != 0:  self.global_q = nn.Conv2d(dim, group_split[-1] * self.dim_head, 1, 1, 0, bias=qkv_bias)  # Q 卷积  self.global_kv = nn.Conv2d(dim, group_split[-1] * self.dim_head * 2, 1, 1, 0, bias=qkv_bias)  # KV 卷积  self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()  # 平均池化  # 将模块添加到 ModuleList 中  self.convs = nn.ModuleList(convs)  self.act_blocks = nn.ModuleList(act_blocks)  self.qkvs = nn.ModuleList(qkvs)  self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)  # 投影层  self.attn_drop = nn.Dropout(attn_drop)  # 注意力权重的 dropout        self.proj_drop = nn.Dropout(proj_drop)  # 输出的 dropout  def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):  '''  高频分支的注意力计算  x: (b c h w) 输入特征  '''        b, c, h, w = x.size()  qkv = to_qkv(x)  # 计算 QKV,得到 (b, 3*m*d, h, w)        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous()  # 混合后得到 (3, b, m*d, h, w)        q, k, v = qkv  # 分解为 Q、K、V  attn = attn_block(q.mul(k)).mul(self.scalor)  # 计算缩放后的注意力  attn = self.attn_drop(torch.tanh(attn))  # 使用 tanh 激活并应用 dropout        res = attn.mul(v)  # 应用注意力权重到 V        return res  def low_fre_attention(self, x: torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):  '''  低频分支的注意力计算  x: (b c h w) 输入特征  '''        b, c, h, w = x.size()  q = to_q(x).reshape(b, -1, self.dim_head, h * w).transpose(-1, -2).contiguous()  # 计算 Q 并调整形状为 (b, m, h*w, d)        kv = avgpool(x)  # 对输入特征进行平均池化  kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h * w) // (self.window_size ** 2)).permute(1, 0, 2, 4, 3).contiguous()  # 计算 KV        k, v = kv  # 分解为 K、V  attn = self.scalor * q @ k.transpose(-1, -2)  # 计算缩放后的注意力  attn = self.attn_drop(attn.softmax(dim=-1))  # 对注意力进行 softmax 和 dropout        res = attn @ v  # 应用注意力权重到 V        res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous()  # 调整形状为原始形状  return res  def forward(self, x: torch.Tensor):  '''  x: (b c h w) 输入特征  '''        res = []  # 保存各分支的输出  for i in range(len(self.kernel_sizes)):  if self.group_split[i] == 0:  continue  res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i]))  # 高频分支输出  if self.group_split[-1] != 0:  res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))  # 低频分支输出  return self.proj_drop(self.proj(torch.cat(res, dim=1)))  # 合并分支输出并应用投影  # 输入 N C HW,  输出 N C H W
if __name__ == '__main__':  block = CloAttention(64).cuda()  # 初始化 CloAttention 模块  input = torch.rand(1, 64, 64, 64).cuda()  # 创建一个随机输入  output = block(input)  # 前向传播  print(f"Input_Size:{input.size()}\nOutput_Size:{output.size()}")  # 打印输入和输出的张量形状

5. 结果

表中报告了ImageNet1K分类结果。 结果表明,当模型大小和FLOPs相似时,模型比以前的模型性能更好。 其中,CloFormer-XXS仅使用4.2万个参数和0.6G FLOPs, Top-1准确率达到77.0%,分别超过ShuffleNetV22x、MobileViT-XS和EdgeViT-XXS 1.6%、2.2%和2.6%

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/19119.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

LVGL-从入门到熟练使用

LVGL简介 LVGL( Light and Versatile Graphics Library )是一个轻量、多功能的开源图形库。 1、丰富且强大的模块化图形组件:按钮 、图表 、列表、滑动条、图片等 2、高级的图形引擎:动画、抗锯齿、透明度、平滑滚动、图层混合等…

【python系列】python数据类型的分类和比较

一、数据类型的定义 在程序设计的类型系统中,数据类型(英语:Data type),又称资料型态、资料型别,是用来约束数据的解释。——Wikipedia 从定义我们可以看出来,数字类型的理解最主要的是约束数据…

SpringBoot(二十七)SpringBoot集成XRebel实现异常定位

之前我使用JRebel实现了IDEA热更新。 这几天我无聊的时候,研究了一下JRebel发现,好像不止JRebel一个插件,同时安装的还有一个XRebel插件,百度了一下,XRebel可以实现异常定位,还有方法的执行分析&#xff0c…

windows上部署flask程序

文章目录 前言一、准备工作二、配置 Gunicorn 或 uWSGI1.安装 Waitress2.修改启动文件来使用 Waitress 启动 Flask 应用3.配置反向代理(可选)4.启动程序访问 三.Flask 程序在 Windows 启动时自动启动1.使用 nssm(Non-Sucking Service Manager…

python调用MySql保姆级教程(包会的)

目录 一、下载MySql 二、安装MySql 三、验证MySql是否OK 1、MySQL控制台验证 2、命令提示符cmd窗口验证 四、Python调用MySql 4.1 安装pysql 4.2 使用pysql 4.2.1、连接数据库服务器并且创建数据库和表 4.2.2 、将人脸识别考勤系统识别到的数据自动填入到数据库的表单中…

如何解决将长视频转换为易于处理的 Spacetime Patch 的问题?

🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ 将长视频转换为易于处理的 Spacetime Patch(时空补丁)是一项挑战,尤其是当视频内容复杂或包含长时间连续场景时。在计算机视觉和视频分析等领域,Spacetim…

大数据学习16之Spark-Core

1. 概述 1.1.简介 Apache Spark 是专门为大规模数据处理而设计的快速通用的计算引擎。 一种类似 Hadoop MapReduce 的通用并行计算框架,它拥有MapReduce的优点,不同于MR的是Job中间结果可以缓存在内存中,从而不需要读取HDFS,减少…

LeetCode 力扣 热题 100道(五)最长回文子串(C++)

最长回文子串 给你一个字符串 s,找到 s 中最长的 回文子串。 回文性 如果字符串向前和向后读都相同,则它满足 回文性 子字符串子字符串 是字符串中连续的 非空 字符序列。 动态规划法 class Solution { public:string longestPalindrome(string s) {i…

dropout层/暂退法

作用:正则化,缓解过拟合 实现方式: 在前向传播过程中,将该层的一部分神经元的输出特征随机丢掉(设为 0),相当于随机消灭一部分神经元仅在训练期间使用,测试时没有神经元被丢掉。 正…

【圆上的连线——卡特兰数】

题目 思路 因为不相交,所以每个点最多连出一条线,所以参与连线的点一定是偶数个 我们按照选出点的数量 2,4 …… 2x 将答案划分,答案可以表示为 (假设我们选出2x个点连线,假设方法数为 :2x个点参…

Pytest-Bdd-Playwright 系列教程(11):场景快捷方式

Pytest-Bdd-Playwright 系列教程(11):场景快捷方式 前言1. 手动绑定场景的传统方法2. 场景快捷方式的自动绑定方法2.1 绑定所有场景2.2 绑定多个路径2.3 自动与手动绑定的结合 3. 示例:结合 Playwright 的实际应用3.1 项目目录结构…

day-17 反转字符串中的单词

利用split()函数和substring函数 code: class Solution {public String reverseWords(String s) {int m0;while(s.charAt(m) ){m;}ss.substring(m);String arr[]s.split("[\\s]");int narr.length;String ss"";for(int in-1;i>1;i--){ssssarr[i]"…

Ubuntu20.04从零安装IsaacSim/IsaacLab

Ubuntu20.04从零安装IsaacSim/IsaacLab 电脑硬件配置:安装Isaac sim方案一:pip安装方案二:预构建二进制文件安装1、安装ominiverse2、在ominiverse中安装isaac sim,下载最新的4.2版本 安装Isaac Lab1、IsaacLab环境克隆2、创建con…

力扣hot100-->二分查找

二分查找 1. 33. 搜索旋转排序数组 中等 整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 < k < nums.length&#xff09;上进行了 旋转&#xff0c;使数组变为 [nums[k], nums[…

Javaweb梳理17——HTMLCSS简介

Javaweb梳理17——HTML&CSS简介 17 HTML&CSS简介17.1 HTML介绍17.2 快速入门17.3 基础标签17.3 .1 标题标签17.3.2 hr标签17.3.3 字体标签17.3.4 换行17.3.8 案例17.3.9 图片、音频、视频标签17.3.10 超链接标签17.3.11 列表标签 17 HTML&CSS简介 今日目标&#x…

倍福PLC数据 转 IEC61850项目案例

目录 1 案例说明 2 VFBOX网关工作原理 3 准备工作 4 设置倍福PLC 5 配置网关参数采集倍福PLC数据 6 用IEC61850协议转发数据 7 网关使用多个逻辑设备和逻辑节点的方法 8 案例总结 1 案例说明 设置倍福PLC&#xff0c;开通ADS通信设置网关采集倍福PLC数据把采集的数据转…

代码辅助工具 GPT / Cursor

代码辅助工具 GPT / Cursor 文章说明GPT辅助效果第一次提问效果第二次提问效果第三第四次提问效果手动微调布局和宽高的效果第五次要求添加主题切换效果第六次提问--继续让它优化主题切换的效果第七次提问--修改主题切换的按钮位置并添加动画提问词第一次提问词第二次提问词第三…

FPGA 常用 I/O 电平标准有哪些?

在 FPGA 的神奇世界里&#xff0c;I/O 电平标准就像魔法咒语&#xff0c;掌控着芯片与外界交流的方式。对于初涉 FPGA 领域的小白来说&#xff0c;这些标准可能有点神秘莫测&#xff0c;但别担心&#xff0c;今天我就用最通俗易懂的方式为你揭开它们的面纱。 一、电平标准的魔…

网络协议(4)拥塞控制

之前已经说过了tcp也是会考虑网络的情况的&#xff0c;也就是当网络出现问题的时候tcp不会再对报文进行重传。当所有的用户在网络不好的时候都不会对丢失的报文进行重传。这样就会防止网络瘫痪。 这样的机制也就是tcp会进行拥塞控制。 拥塞控制 所谓的慢启动看下面这张图就能…

#define定义宏(2)

大家好&#xff0c;今天给大家分享两个技巧。 首先我们应该先了解一下c语言中字符串具有自动连接的特点。注意只有将字符串作为宏参数的时候才可以把字符串放在字符串中。 下面我们来讲讲这两个技巧 1.使用#&#xff0c;把一个宏参数变成对应的字符串。 2.##的作用 可以把位…