【DETR】

img
https://tianfeng.space/

前言

论文 代码

DETR(Data-efficient Image Transformer)是一种用于目标检测任务的深度学习模型。它与传统的目标检测方法不同,采用了Transformer架构,将目标检测问题转化为一个序列到序列的问题。以下是DETR模型的一些关键特点:

  1. Transformer架构: DETR采用了Transformer架构,这是一种用于自然语言处理的架构,但在DETR中被用于图像处理。这种架构允许模型同时处理整个图像,而不是传统的滑动窗口或区域提议方法。

  2. 序列到序列:DETR将目标检测问题建模为一个序列到序列的问题,其中输入序列是图像的嵌入表示,输出序列是目标的嵌入表示。这种方法允许模型根据图像上的所有信息来预测目标。

  3. 位置嵌入: DETR引入了位置嵌入,用于指示目标在图像中的位置。这些位置嵌入与目标的嵌入结合起来,帮助模型预测目标的位置。

  4. 多头注意力: 模型使用多头自注意力机制,允许它关注不同位置的图像信息以预测目标的位置和类别。

  5. 无需锚框:与传统的目标检测方法不同,DETR不需要使用锚框(anchor boxes)或区域提议网络(Region Proposal Network)。它直接从输入图像中x生成目标框,这使得模型更简洁和易于训练。

框架解读

img

基本思想

使用ResNet作为backbone提取图片特征,同时会使用一个1*1的卷积进行降维。因为transformer的编码器模块只处理序列输入,所以后续还需要把CNN特征展开为一个序列。

先来个CNN得到各Patch作为输入,再套transformer做编码和解码编码路子跟VIT基本一样,重在在解码,直接预测100个坐标框。CNN 的特征提取部分没有什么可以说的,目标检测的图一般比较大,那么直接上 Transformer 计算上吃不消,所以先用 CNN 进行特征提取并缩减尺寸,再使用 Transformer 是常规操作。

DETR使用的典型值是C = 2048和H,W = H0 / 32,W0 / 32;C=2048 是每个 token 的维度,还是比较大,所以先经过一个 1 × 1 的卷积进行降维,然后再输入 Transformer Encoder 。此时自注意力机制在特征图上进行全局分析,因为最后一个特征图对于大物体比较友好,那么在上面进行 Self-Attention 会便于网络更好的提取不同位置不同大物体之间的相互关系的联系,然后位置编码是被每一个 Multi-Head Self-Attention 前都加入了的。

将ResNet提取的特征图转成特征序列后,图像就失去了像素的空间分布信息,所以Transformer就引入位置编码。把特征序列和位置编码序列拼接起来,作为编码起的输入。

img

整体网络架构

DETR 分为四个部分,首先是一个 CNN 的 backbone,Transformer 的 Encoder,Transformer 的 Decoder,最后的预测层 FFN。

DETR使用传统的CNN主干网络来学习输入图像的2D表示。该模型对其进行扁平序列化(大的卷积核和步长使其变成一个个patch,并行展开输入Encoder),并在将其传递到转换器编码器之前用位置编码对其进行补充。然后,转换器解码器将少量固定数量的学习位置嵌入作为输入,我们称之为对象查询,并额外处理编码器输出。我们将解码器的每个输出嵌入传递到共享前馈网络(FFN),该网络预测检测(类和边界框)或“无对象”类。(论文预测100框)

object queries是核心,让它学会怎么从原始特征找到是物体的位置

img

Encoder完成的任务

得到各个目标的注意力结果,准备好特征,等解码器来选秀

img

Decoder

输出层就是100个object queries预测编码器,解码器首先随机初始化object queries(0+位置编码,),先自己self attention学习一下;然后用解码器学到的q去查询编码的KV,通过多层让其学习如何利用输入特征。

输出的匹配

GT只有两个,但是预测的恒为100个,怎么匹配呢?匈牙利匹配完成,按照LOSS最小的组合,剩下98个都是背景。集合到集合的预测看起来非常直接,但是在训练的过程就会遇到一个问题,就是如何把预测出来的100个框与ground truth做匹配,然后得到损失。DETR就非常暴力,直接利用pd(predicttion)与gt(ground truth)按照最小权重做一对一匹配,剩余的框全部当做背景处理。

此权重的构成:

分类损失:这里分类损失是由直接softmax的值取出来的。举个例子:预测100个目标框,每个目标框有92个候选类别,经softmax输出后有out,shape=(100,92)。根据groundtruth的target标签假设(有20个),根据这些类别值直接作为索引值筛选出每个预测目标框的类别以及概率,最后剩下了=(100,20)的softmax的值。也就是说只把图片内存在的类别作为交叉熵损失的选择,然后用softmax来作为损失,由于1是常数,直接进行了一个省略。目标框的损失是将预测的目标框,与gt中每个目标框做L1损失,假设gt有20个目标框,就会产生200*20个损失值。同上,求IOU并取负做损失,损失加权求和作为总损失。

然后利用匈牙利匹配出目标框,将预测框的索引值和对应位置的gt目标狂的索引配对输出。其余的就直接抛弃。

该算法实现预测值与真值之间最优的匹配,并且是一一对应,不会多个预测值匹配到同一个ground truth上,这样就无需NMS后处理了。假设预测结果是N个,那么标注信息也要是N个,假设N=6,但真实标签2个,剩下的4个(标注如果小于N就用无物体信息去填充)标注信息都是用无类别来填充。

注意力起到的作用

这个注意力挺有意思,能不被遮挡,照样可以学出来(注意颜色)

img

细节

decoder中的位置肯定最重要了,这个得学习才行;每层都预测(Auxiliary);100个预测框之间可以相互通信,训练用了多个卡,

100个框各自要干啥

论文中可视化了其中20个,绿色是小物体,红蓝是大物体基本描述了各个位置都需要关注,而且它们还是各不相同的

img

额外证明

transformer不仅在检测领域好使,分割里照样行(感觉就像是让一群人去做分割,每个人做其中一块,最后合并一起)

img

img

简单使用

环境配置

下载代码

git clone https://github.com/facebookresearch/detr.git

下载pytorch和torchvision必须的

conda install -c pytorch pytorch torchvision

安装scipy和pycocotools

conda install cython scipy
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

数据集下载

https://cocodataset.org/#download

也提供网盘链接:https://pan.baidu.com/s/1RM_9Eip_-94eJtL23fEM5Q
提取码:icnt

分别为标注文件,训练集和测试集

path/to/coco/annotations/  # annotation json filestrain2017/    # train imagesval2017/      # val images

模型训练

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --coco_path /path/to/coco

模型评估

python main.py --batch_size 2 --no_aux_loss --eval --resume https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --coco_path /path/to/coco

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/143062.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Java之IO流概述

1.1 什么是IO 生活中,你肯定经历过这样的场景。当你编辑一个文本文件,忘记了ctrls ,可能文件就白白编辑了。当你电脑上插入一个U盘,可以把一个视频,拷贝到你的电脑硬盘里。那么数据都是在哪些设备上的呢?键…

【数据库——MySQL】(10)视图和索引

目录 1. 视图1.1 创建视图1.2 查询视图 2. 索引2.1 索引的分类2.2 索引的建立 参考书籍 1. 视图 1.1 创建视图 基础语法: CREATE [OR REPLACE] VIEW 视图名[(列名表)]ASSELECT语句[WITH CHECK OPTION]说明: 在默认情况下,将在当前数据库创…

Linux 用户 用户组管理

用户 Linux系统是一个多用户多任务的分时操作系统,任何要使用系统资源的用户,都必须首先向系统管理员申请一个账号,然后以这个账号的身份进入系统。每个用户账号都拥有一个唯一的用户名和各自的口令。用户在登录时键入正确的用户名和口令后&a…

华为ICT——第二章-数字图像处理私人笔记

目录 1:计算机视觉:​编辑 2:计算机视觉应用:​编辑 3:计算机视界核心问题:​编辑 4:相关学科: 5:计算机视觉与人工智能: 最成熟的技术方向是图像识别 6…

【面试算法——动态规划 20】最长公共子序列 不相交的线

1143. 最长公共子序列 链接: 1143. 最长公共子序列 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字…

Spring面试题8:面试官:说一说Spring的BeanFactory

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:说一说Spring的BeanFactory Spring的BeanFactory是Spring框架的核心容器,负责管理和创建Bean对象。它是一个工厂类,用于实例化、配置和管理Bean的…

SpringBoot 如何使用 Druid 进行数据库连接池管理

使用 Druid 进行数据库连接池管理的 Spring Boot 应用 数据库连接池是任何Web应用程序的重要组成部分,它们有助于管理数据库连接的复用,提高性能和资源利用率。Druid是一个强大的数据库连接池,它具有监控、防SQL注入、快速、可扩展等特点。在…

谈谈最近招人的感受!

最近折腾新的项目,面试了很多实习生小伙伴,我说说我的一些「面试」感受, 虽然是一个老生常谈的话题,但是依然提一下。 准时很重要:提前一点时间,踩个点,别迟到,面试的过程中由于每个…

低功耗引擎Cliptrix为什么可以成为IOT的高效能工具

在万物互联的时代,现代人已普遍接受电视、音箱等电器设备具备智能化能力,也是在这个趋势下,我们身边越来越多的iOT设备联网和交互成为刚需。 但iot设备也面临到一些非常显著的痛点,例如iot设备的内存、处理器等核心元件无法与手机…

爬虫 — 多线程

目录 一、多任务概念二、实现多任务方式1、多进程 (Multiprocessing)2、多线程(Multithreading)3、协程(Coroutine) 三、多线程执行顺序四、多线程的方法1、join()2、setDaemon()3、threading.enumerate() …

python运算函数

简 python输入输出函数input() :用户用于读取键盘输入的函数,返回值为“string”类型 运算函数abs(x) :x的绝对值int(x) :将x转换成整型(截掉小数部分)float(x):浮点数divmod(x,y):返回(x//y,x%y)complex(re,im):返回一…

linux部署页面内容

/bin:该目录包含了常用的二进制可执行文件,如ls、cp、mv、rm等等。 /boot:该目录包含了启动Linux系统所需的文件,如内核文件和引导加载程序。 /dev:该目录包含了所有设备文件,如硬盘、光驱、鼠标、键盘等等…

Scoket网络编程

1.首先来的个简单示例: 客户端: using System; using System.Net.Sockets; using System.Net; using System.Text;namespace Client {internal class Program{static void Main(string[] args){Console.WriteLine("Client");// 创建一个Socket并连接到服…

windows11 cmd使用python没有反应, windows11使用python跳应用商店

1. 修改系统变量位置,右击我的电脑,选择属性: 点击环境变量,找到path: 将python 的path移到windowsapp 上侧 保存退出。重新打开cmd,输入命令python -v

网络通信(套接字通信)(C/C++)

1.网络编程必知概念 1.广域网和局域网 广域网:又称外网、公网。是连接不同地区局域网或城域网进行计算机通信的远程公共网络。 局域网:在一定的通信范围内,有很个多计算机组成的私有网络就叫局域网。(这些计算机相互之间是可以通信的,但是不能直接访问外网(可以通过网线…

虹科方案 | LIN/CAN总线汽车零部件测试方案

文章目录 摘要一、汽车零部件测试的重要性?二、虹科的测试仿真工具如何在汽车零部件测试展露头角?三、应用场景**应用场景1:方向盘开关的功能测试****应用场景2:各类型电机的控制测试****应用场景3:RGB氛围灯的功能测试…

CISSP,你值得拥有(我的学习之路)

(只分享三点:怎么学、怎么练、怎么考。) 我为啥去考CISSP 我是个在信安行业摸爬滚打将近20年的老油条,知道CISSP这个认证是很早前的事情了,但一直以来都觉得它有点难,加上人又懒得要命,也就始…

安装elasticsearch

1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的镜像,这个镜像体积非常大,接近1G。不建议大家自己pull。 课前资料提…

用selenium和xpath定位元素并获取属性值以及str字符型转json型

页面html如图所示: 要使用xpath定位这个div元素,并且获取其属性data-config的内容值。 from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.chrome.options import Optionshost127.0.0.1 port10808 …

Serlet API详解

目录 一、HttpServlet 1.1 处理doGet请求 1.2 处理doPost请求 二、HttpServletRequest 2.1 核心方法 三、HttpServletRespons 3.1 核心方法 一、HttpServlet 在编写Servlet代码的时候,首先第一步要做的就是继承HttpServlet类,并重写其中的某些方法 核心…