视觉常用Backbone大全:VisionTransformer(ViT)

       视觉常用Backbone大全

       今天介绍的主干网络模型叫VisionTransformer,是一种将 Transformer 架构应用于计算机视觉任务的模型,通过将图像进行切块,将图片转变为self-attention认识的token输入到Transformer模块中,实现了Transformer架构在视觉领域的应用;

一、模型介绍

       Transformer 最初是由 Vaswani 等人在 2017 年的论文《Attention is All You Need》中提出,主要用于自然语言处理任务。2020 年,Google Research 在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中首次将 Transformer 应用于图像识别任务,提出了 ViT 模型。 

       上图左边就是ViT模型的整体架构,模型整体由三部分构成,分别是Linear Projection、Transformer Encoder、MLP Head;接下来就对这几个模块做进一步的分析;

二、模块分析

2.1 Linear Projection

在这个模块里,模型对数据进行了切分和编码操作,具体操作如下:

       先说切分,如上图左下角所示,将一张图片切分成9块那样,假设原图像的尺寸为224*224*3,我们想要使得切出来的小块的尺寸为16*16,那么我们就需要将原图分成(224 / 16 )^2 = 196块,每一块的尺寸为16*16*3,这个操作叫做patch,在实际代码中,patch的裁剪是用一个patch_size大小的卷积同时以patch_size的步长进行卷积实现的;

       在NLP中,一段文本在输入到transformer模块之前需要将文本送入一个可训练的encode模型进行编码,模型会将文本进行切分,再将切分后的文本转换成token张量,这样一个token张量就是一个一维的向量;

       在图像中也是一样,我们将图像进行切分,但是切分后的小图像块还是三维数据,所以我们还需要将这样的每一个小图像块进行维度转换,将其变为一维的向量将其视为一个token,这样转换后的小图像块尺寸就变为了(1,768),我们一共有196个这样的图像块,所以转换维度之后的图像数据维度为(196,768);

       在转换为token张量之后,我们还需要对其添加位置信息,位置信息张量的维度与现有token维度相同,通过add方式(对应元素相加)进行融合;

       最后,输入tensor还需要有一个class token(分类层),数据格式和其他token一样为(1,768),与位置编码的融合方式不一样,这里做的是Concat(维度上的拼接),这样做是因为分类信息是在后面需要取出来单独做预测的,所以不能以Add方式融合,shape也就从(196, 768)变为(197, 768);

       至此,Linear Projection部分就全部完成,token将被送入Transformer Encoder模块;

2.2 Transformer Encoder

       在整个Transformer Encoder中,实际上就是如上图的几个encoder block模块的堆叠,我们可以看到在encoder block模块中主要包含Layer Normalization、Multi-Head Attention、DropOut/DropPath和MLP四部分;

2.2.1 Layer Normalization

       Layer Normalization的作用和BatchNorm是同样的作用,都是将数据进行正则化处理,使得模型可以加速收敛,区别在于两者对数据处理的维度不同,BN针对的是批量数据中的某个维度进行操作,而LN则是针对某个样本的所有维度进行操作;

2.2.2 Multi-Head Attention

       这里的多头自注意力机制应用的就是Transformer的Multi-Head Attention,这里就简单说一下self-attention以及Multi-Head Attention;

       对于self-attention机制,它接收的是token张量,每一个token张量进入self-attention后,都会与可训练权重w_{q}w_{k}w_{v}进行矩阵运算,生成对应的Q、K、V张量,然后不同token的Q和K进行运算(有公式)生成紧密权重\alpha张量,最后再将\alpha与V进行运算,生成新的token张量,这就是自注意力机制的计算流程;

       而对于Multi-Head Attention,它与self-attention的区别在于同一个token可以生成多个Q、K、V张量对,中间的运算流程都是相同的,最后会输出多个新的token张量,再将其加权合并为一个token进行输出;

2.2.3 MLP

        这个就是一个简单的前馈神经网络,通过全连接层、激活层、dropout层的串联对token进一步做特征提取和特征融合的操作;

2.2.4 Transformer Encoder代码

# transformer编码class Block(nn.Module):def __init__(self,dim,num_heads,mlp_ratio=4.,qkv_bias=False,qk_scale=None,drop_ratio=0.,attn_drop_ratio=0.,drop_path_ratio=0.,act_layer=nn.GELU,norm_layer=nn.LayerNorm):super(Block, self).__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout hereself.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x

三、拓扑结构

       这里以最常使用的ViT-B/16为例,看一下它的拓扑结构以及代码实现;

       如上图所示就是ViT-B/16模型的拓扑结构,从结构上可以看出ViT-B/16模型输入尺寸为224*224,patch_siae=16*16,进过编码后输入到由12个多头自注意力机制的transformer encoder堆叠起来的网络中,最后通过一个MLP head进行分类;

四、代码实现

       下面是ViT-B/16模型基于pytorch的实现:

# ViT-B/16import torch
import torch.nn as nn
import torch.nn.functional as Fclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):super(PatchEmbedding, self).__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.proj(x).flatten(2).transpose(1, 2)return xclass PositionalEncoding(nn.Module):def __init__(self, embed_dim, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, embed_dim)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0), :]return self.dropout(x)class TransformerEncoderLayer(nn.Module):def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):super(TransformerEncoderLayer, self).__init__()self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)self.linear1 = nn.Linear(embed_dim, mlp_dim)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(mlp_dim, embed_dim)self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)def forward(self, src):src2, _ = self.self_attn(src, src, src)src = src + self.dropout1(src2)src = self.norm1(src)src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))src = src + self.dropout2(src2)src = self.norm2(src)return srcclass VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_dim=3072, num_classes=1000):super(VisionTransformer, self).__init__()self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)self.pos_embed = PositionalEncoding(embed_dim, max_len=self.patch_embed.num_patches + 1)self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.transformer = nn.ModuleList([TransformerEncoderLayer(embed_dim, num_heads, mlp_dim) for _ in range(depth)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = self.pos_embed(x)for layer in self.transformer:x = layer(x)x = self.norm(x)x = self.head(x[:, 0])return x# 示例用法
if __name__ == "__main__":model = VisionTransformer()input_tensor = torch.randn(1, 3, 224, 224)  # 假设输入图像大小为224x224output = model(input_tensor)print(output.shape)  # 输出形状应为 (1, 1000)

五、模型优缺点

优点:

  1. 全局依赖关系:通过自注意力机制,ViT 能够捕获图像中的全局依赖关系,这对于理解复杂的视觉场景非常有用。
  2. 灵活的输入表示:ViT 的 patch-based 输入表示方式使得模型可以灵活地处理不同分辨率的图像。
  3. 强大的特征提取能力:Transformer 的强大建模能力使得 ViT 在大规模数据集上表现出色。
  4. 端到端训练:ViT 可以从头开始训练,不需要复杂的预处理步骤。

缺点:

  1. 计算成本高:ViT 的自注意力机制计算复杂度较高,特别是对于高分辨率图像,计算量和内存消耗都非常大。
  2. 数据需求大:ViT 需要在大规模数据集上进行训练才能取得良好的性能,对于小规模数据集的效果可能不如传统的卷积神经网络。
  3. 过拟合风险:由于模型参数量较大,ViT 在小规模数据集上容易发生过拟合。
  4. 训练不稳定:ViT 的训练过程可能不够稳定,需要仔细调整超参数和优化策略。

       ViT模型相比于CNN架构模型优点在于它可以借助transformer全局信息互通、信息融合的特性来从全局的角度进行特征信息的提取,可以有效提高对复杂图像的理解能力,但图像信息却又不像文本信息那样有很强的上下文关联性,甚至图像缺少部分像素也不影响对图像的识别任务,所以这种全局的强关联性又是比较冗余的信息,同时还加大了运算量;即我们既希望模型可以多关注一点全局信息的特征,但又不希望过多的去关注全局的特征,这个是ViT模型所存在的问题。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/20641.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

星海智算:Stable Diffusion3.5镜像教程

Stable Diffusion3.5 模型介绍 Stable Diffusion 3.5是由Stability AI推出的最新图像生成模型,它是Stable Diffusion系列中的一个重大升级。这个模型家族包括三个版本,分别是Stable Diffusion 3.5 Large、Stable Diffusion 3.5 Large Turbo和Stable Dif…

[JavaWeb] 尚硅谷JavaWeb课程笔记

1 Tomcat服务器 Tomcat目录结构 bin:该目录下存放的是二进制可执行文件,如果是安装版,那么这个目录下会有两个exe文件:tomcat10.exe、tomcat10w.exe,前者是在控制台下启动Tomcat,后者是弹出GUI窗口启动To…

【Unity基础】认识Unity中的包

Unity中的包是一个核心概念,像Unity本身的功能的扩展,或者项目中资源的管理,都是通过包的形式来实现的。 一、什么是包? 一个包包含满足您项目各种需求的功能。这可以包括编辑器安装过程中附带的任何核心Unity功能,也…

elment-ui的折叠tree表单实现纯前端搜索,展开收起功能

好久没更新博客了~ 记录一下本次做的一个很扯的需求 纯前端去实现这个查询的功能,后台返回的是个数组对象,前端要给他包装成树结构先展示 之后参考代码路径src\views\goods\category\index.vue 需求描述: 搜索输入任何一个关键字,都会展开他的父级,两个栏目都包含了,那么两个父…

linux先创建文件夹后指定创建文件夹用户

1、创建文件夹,然后创建用户并指定用户目录,然后修改目录所有权给该目录 # 创建 /home/test 目录 mkdir /home/test # 设置权限(确保有适当的读写权限) chown root:root /home/test chmod 700 /home/test # 创建 xl 用户并指定家…

大模型(LLM)全参数微调有哪些技巧,常用的轻量级微调有哪些,微调策略应该如何选择?

大家好,我是微学AI,今天给大家介绍一下大模型(LLM)全参数微调有哪些技巧,常用的轻量级微调有哪些,微调策略应该如何选择?本文将从大模型(LLM)全参数微调技巧,常用的轻量级微调方法,微调策略应该…

蓝牙电话-如何自动切换手机SIM卡(预研方向)

蓝牙电话-如何自动切换手机SIM卡(预研方向) 一、前言 最近突然有客户问说,蓝牙电话的app既然已经能统计手机里面插了多少张卡,那能不能做双卡的SIM卡自动切换?即:设置一个呼叫策略和频率,当打…

【蓝桥杯C/C++】C语言和C++的常量概念与区别分析

博客主页: [小ᶻZ࿆] 本文专栏: 蓝桥杯C/C 文章目录 💯前言💯常量的概念和作用💯C语言中 const 的应用与限制#define 和 enum 的使用方法 💯C 中 const 的计算方法和处理💯代码实例和应用区别&#x1f…

全面解析亚马逊云服务器(AWS):功能、优势与使用指南

亚马逊云服务器(AWS)概述 亚马逊云服务器(Amazon Web Services,简称AWS)是全球领先的云计算平台,提供一系列强大且灵活的云服务,帮助企业和开发者通过云基础设施实现数据存储、计算、分析和机器…

“小浣熊家族AI办公助手”产品体验 — “人人都是数据分析师”

一、引言: 大家平时应该在工作中常常使用到Excel来做数据统计,比如临近过年时,公司一般会开各种复盘、年终、检讨、明年规划大会,势必需要准备一大堆的量化数据报表,用于会议上的数据汇报、分析工作,试想一…

C盘扩容(C盘右键无法扩展卷解决)超详细步骤!!!

目录 1、问题及需求2、解决办法方法2 1、问题及需求 今天一看C盘爆红了,但是D盘还剩很多空间,想要从D盘再分出来50G给C盘。 但是压缩了D盘,在C盘扩展卷,实现不了,因为不仅挨着。看下边的解决办法 2、解决办法 桌面上…

机器学习笔记 // 天气预报、股票价格以及历史轨迹(如摩尔定律)// 时间序列的常见属性

时间序列随处可见。你可能已经在天气预报、股票价格以及历史轨迹[如摩尔定律,见下图​]等事物中见过它们。摩尔定律预测微芯片上面的晶体管个数大约每两年会翻倍。几乎50年以来,它已经被证明对未来的计算能源和成本来说是一个准确的预测器。 许多时间序列…

mysql日志写满出现The table ‘xxxx_amazon_order’ is full

数仓发现写数据出现 SQL 错误 [1114] [HY000]: The table ‘xxxx_amazon_order’ is full 1.第一时间查看系统磁盘, 发现空间写满了 df -h因为mysql是使用docker部署的, Docker 的默认存储位置在 /var/lib/docker /var 目录默认是在根分区 (/dev/mapper/centos-root) 下的 …

(一)Ubuntu22.04服务器端部署Stable-Diffusion-webui AI绘画环境

一、说明 cup型号: Intel(R) Celeron(R) CPU G1610 2.60GHz 内存大小: 8G 显卡型号:NVIDIA P104-100 注意:系统睡眠问题 sudo systemctl mask sleep.target suspend.target hibernate.target hybrid-sleep.target 网卡设置 …

springboot:少量配置信息情形

发现无论怎么改都还是指向8001 所以换一种方法 通过 结果 代码 import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.web.server.ConfigurableWebServerFactory; import org.springframework.boot.web.server.WebServerFactoryCusto…

SpringBoot的快速入门

Maven Maven可以方便管理依赖的 Jar 包 IDEA 自带Maven,也可以选择自己安装 安装Maven:https://blog.csdn.net/qq_59636442/article/details/142314019 创建项目 通过Spring Initializr 快速创建项目:https://start.springboot.io/ 我的项目名叫blog&a…

2024中国高校计算机大赛 — 大数据挑战赛-赛后复盘

一、赛题描述 基于气象大数据的自动站实况联合预测 风光清洁能源的管理与气象关系密不可分,因为风能和太阳能的发电效率直接依赖于气象条件。风力发电需要精确的风速和风向预测,而太阳能发电则依赖于日照时间和云层覆盖情况的准确预报。优质的气象预测…

J.U.C - 深入解析ReentrantLock原理源码

文章目录 概述synchronized的缺陷1)synchronized不能控制阻塞,不能灵活控制锁的释放。2)在读多写少的场景中,效率低下。 独占锁ReentrantLock原理ReentrantLock概述AQS同步队列1. AQS实现原理2. 线程被唤醒时,AQS队列的…

基于Java+Springboot+Jpa+Mysql实现的在线网盘文件分享系统功能设计与实现二

一、前言介绍: 免费学习:猿来入此 1.1 项目摘要 在线网盘文件分享系统的课题背景主要源于现代社会对数字化信息存储和共享需求的日益增长。随着互联网的普及和技术的快速发展,人们越来越依赖电子设备来存储和传输各种类型的数据文件。然而…

DBSCAN聚类——基于密度的聚类算法(常用的聚类算法)

DBSCAN(Density-Based Spatial Clustering of Applications with Noise)简称密度聚类或密度基础聚类,是一种基于密度的聚类算法,也是一种常用的无监督学习算法,特别适用于形状不规则的聚类和含有噪声的数据集。主要用于…