【论文复现】Vision Transformer(ViT)

1. Transformer结构

在这里插入图片描述

1.1 编码器和解码器

翻译这个过程需要中间体。也就是说,编码,解码之间需要一个中介,英文先编码成一个意思,再解码成中文。

那么查字典这个过程就是编码和解码的体现。首先我们的大脑会把它编码,编码这个句子的意思,然后通过字典映射解码。但是这样的过程太过于繁琐,如果让机器做,超长文本就对应着超长的数据量,也不利于机器学习的上下文理解。那么就有了Attention注意力机制。

1.2 Attention:注意力机制

Attention机制的核心思想是,要想翻译一个句子并不需要完全编码。像我们人类一样,仅凭借几个词就可以猜出整句话的大概意义,即使我们不懂日语,也可以根据一些汉字推出来大概的意思,这是准确度低的情况;而“中译中”这种情况,准确度当然就更高一些。

Attention注意力机制:

Attention示意图本质上是加权平均。如果我非常注意某个地方,我想要多看,那就分配更高的权重。

计算权重是使用相似度计算,

Attention机制的优点

Attention的优点是能够实现并行计算和全局视野。

并行计算的可能性是因为它不像RNN一样,依赖时序数据。它只是加权计算,但并不需要像时序数据那样,依赖像是队列一样的进出顺序。

全局视野是因为在加权计算的时候,这个计算就是涉及了整体的,它一看就能看到全部。

1.3 Self Attention

对于Self Attention来说,它的输入是一个序列,序列的获得依靠的是vector。我们把一个词转换为序列模块,需要用到vector向量去指向。而vector的指向,是有空间性的,比如说两个意思很相近或者同义的词汇,它们在空间中的距离就会比较小。相反,意思差的多,距离当然就远。

这样也可以理解为vector是和词语的意义有关系的。

注意力分配的多少取决于公式:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中,Q代表Queries(输入的信息),K代表Keys(内容信息),V代表Values(信息本身,只表达了信息的特征)。

在这里插入图片描述

Q,K,V的获得

本质上,是input的线性变换。计算使用的是矩阵乘法,实现方法是nn.Linear。

用点积表示相似度的方法是因为cos角投影长度可以很直观地理解两个量的相似度。

在1.3的式子里除以 d k \sqrt{dk} dk 这一步看起来很多余,但是它是为了避免较大的数值。较大的数值会导致softmax之后值更极端,softmax的极端值会导致梯度消失。这一步相当于控制了数值范围,让它在可观测范围内。

为什么是 d k \sqrt{dk} dk ?假设q,k是均值为0,方差为1的标准正态分布的独立随机变量,那么它们的点积的均值和方差分别为0以及dk。

之后做逐元素相乘(enterwise)。

我们需要将单词意思转换为句中意思,这就涉及一词多义的问题,在逐元素相乘得到的sum之后的z就涉及这个问题。那么我们需要根据句子中其他词来推理当前词的意思。

就比如说Mine,它有两种意思,一种是“我的”,一种是“矿石”。这当然是截然不同的词性和词义。假设我们完全不知道这个多义词,但我们可以通过观察它们在句子中的位置和与上下文的联系来推理这是什么意思。

1.4 Multihead Attention

这一步的核心是复读机()

这一步就是有多个W_q,W_k,W_v,那么上述操作重复多次,将结果用concat串在一起。

这样的复读机机制就是给注意力提供多种可能性。

应用了multihead的Conditional DETR就发现不同的head会将注意力放到物体的不同边上。
在这里插入图片描述

1.5 输入端适配

直接把图片切分成patches,flatten操作拉平patches,然后过一个linear projection使patches维度变小,然后编号123456789…输入网络即可。

就是切蛋糕喂给encoder和decoder。

在这里插入图片描述

这块儿有个patch 0的原因,有一种说法是从NLP来的:为了保持整体结构,变换尽可能的少。而NLP需要一些token负责输出,需要“终止输出”的功能模块。另一种CV里的说法是整合信息,设置在1-9之外就保持了1-9本身无干扰。Patch 0本质上是dynamic pooling layer。

1.6 位置编码

图像切分重排后失去了位置信息,而transformer的内部运算与空间信息无关。这样一来,就需要把位置信息编码重新传进网络。ViT使用了一个可学习的vector来编码,维度和patch维度一样,所以编码vector和patch vector直接相加组成输入。本质是相加,而相加是concat的一种特例。

1.7 ViT结构的数据流

输入图像是256x256像素大小,然后切开,切成N(8x8=64)个小块,每一块则是256/8=32单位长(宽)度。也就是说,现在每一小块儿是32x32。把切开的每一小块都拉平,RGB值为3,每一块儿的维度就是3x32x32=3072维。但是3072维太高了,所以过一个linear projection把维度变成1024。但是此时每个小块儿的空间位置丢失了。所以需要加上position embedding这个可学习的向量,维度一样也是1024,让他们相加。position embedding放在patch0这里,一起进入transformer。

进入了transformer encoder之后,首先因为多了一个patch 0,Patches的表示向量里数量取N+1,即(b,65,1024)。在这个norm层里patches会被归一化,一直检验维度,保证维度是一样的。

最后到MLP Head手里,就只输入负责整合信息的patch 0,此时它表示为(b,1,1024)。这样就可以做分类任务了。
在这里插入图片描述

1.8 训练方法

Transformer 非常吃数据量,需要大量的样本,大规模使用Pre-Train。它先在大数据集(ImageNet)上预训练,然后到小数据集上做 Fine Tuning。

迁移过去之后,需要把原本的MLP Head换掉,换成对应类别数的FC层。处理不同尺寸输入的时候需要对Postional Encoding的结果进行插值。

插值方法:图片切好了之后,编号,但不同的input size和patch size会切出不同数量的patch,position embedding也会变。所以编号的方法需要缩放。

1.9 实验结果

Transformer的性能需要庞大数据量的保证,很吃资源。否则,无法充分的发挥出它的性能。它和ResNet的性能不相上下

Attention的距离可以等价为Conv的感受野大小。越深的层数,Attention跨越的距离越远。在最底层,也有head能覆盖到很远的距离。这说明它确实在捕捉信息,做信息整合。

模型注意力集中的地方,都和分类的语义高度相关。

2. 代码复现

VIT仓库链接:https://github.com/lucidrains/vit-pytorch

Usage:

import torch
from vit_pytorch import ViT # 抽象出了一个VIT类v = ViT(image_size = 256,# 图片像素大小patch_size = 32, # patch的大小num_classes = 1000, # 分类数量dim = 1024,# 维度depth = 6, # transformer的block数量heads = 16, # 线性变换后输出张量的最后维数,多头注意力层中的头数mlp_dim = 2048, # MLP前馈层维度dropout = 0.1, # 每个训练步骤中被关闭神经元的比例,可以调成0emb_dropout = 0.1 # 嵌入丢失率
)img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)

2.1 切图重排

输入端适配涉及到切图和reshape。

以下是ViT部分.py代码:

import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim = -1)self.dropout = nn.Dropout(dropout)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):x = self.norm(x)qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)attn = self.dropout(attn)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn self.norm(x)class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Linear(dim, num_classes)
# 解析代码段def forward(self, img):x = self.to_patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x)

transpose函数的作用:重排张量(tensor)或者数组的维度。有一个形状为(batch_size, channels, height, width)的四维张量,代表一批图像数据。那就可以将将channels维度移动到最前面,即形状变为(channels, batch_size, height, width)。这时,你就可以使用transpose操作来实现这一转换。(类似于矩阵行变换)

img = torch.randn(1,3,256,256)b=1
c=3
h=256 = h*p1,h = 8
w =256
self.to_patch.embedding = nn.Sequential(Rearrange(‘b c (h p1)(w p2)-> b(h w)(p1 p2 c),p1 = patch_height, p2 = patch_width),# 图片切分重排nn.Linear(patch_dim, dim) # Linear Projection of Flattened Patches)

2.2 构造Patch 0

这一步:

cls_tokens = repeat(self.cls_token, '() n d -> b n d',b = b)x = torch.cat((cls_tokens, x), dim = 1) # concat方法,维度为1,在n的维度上

2.3 positional embedding

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 位置编码:它是一个可学习的参数,初始化为随机值。
x += self.pos_embedding[:,:(n+1)]
# 将位置编码加到输入序列上。

2.4 代码示例

首先准备数据集。

from __future__ import print_functionimport glob
from itertools import chain
import os
import random
import zipfileimport matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DataSet
from torchvision import datasets, transforms
from tqdm.notebook import tqdmfrom vit_pytorch import ViT
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
def seed_everything(seed):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Trueseed_everything(seed)device = 'cuda'
os.makedirs('data',exist_ok = True)train_dir = 'data/train'
test_dir = 'data/test'with zipfile.ZipFile('data/train.zip') as train_zip:train_zip.extractall('data')
with zipfile.ZipFile('data/test.zip') as test_zip:test_zip.extractall('data')train_list = glob.glob(os.path.join(train_dir,'*.jpg')) # 查找匹配的jpg文件
test_list = glob.glob(os.path.join(test_dir,'*.jpg'))print(f'Train Data:{len(train_list)}')
print(f'Test Data:{len(test_list)}')labels = [path.split('/')[-1].split('\\')[-1].split[0] for path in train_list]
print(train_list[0]
print(labels[0]))
random_idx = np.random.randint(1,len(train_list),size = 9)
fig, axes = plt.subplots(3,3,figsize = (16,12))for idx,ax in enumerate(axes.ravel()):img = Imag.open(train_list[idx])ax.set_title(labels[idx])ax.imshow(img)
train_list, valid_list = train_test_split(train_list, test_size = 0.2, stratify = labels, random_state = seed)print(f'Train Data:{len(train_list)}')
print(f'Validation Data:{len(valid_list)}')
print(f'Test Data:{len(test_list)}')
train_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)val_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)test_tranforms = tranforms.Compose([transforms.Resize((224,224)),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),]
)
class CatsDogsDataset(Dataset):def __init__(self, file_list, transform = None):self.file_list = file_listself.transform = transformdef __len__(self):self.filelength = len(self.file_list)return self.filelengthdef __getitem__(self,idx):img_path = self.file_list[idx]img = Image.open(img_path)img_transformed = self.transform(img)label = img_path.split('/')[-1].split("\\")[-1].split("、")[0]label = 1 if label == "dog" else 0return img_transformed, label
train_data = CatsDogsDataset(train_list, transform = train_transforms)
valid_data = CatsDogsDataset(valid_list, transform = val_transforms)
test_data = CatsDogsDataset(test_list, transform = test_transforms)
train_loader = DataLoader(dataset = train_data,batch_size = batch_size,shuffle = True)
valid_loader = DataLoader(dataset = valid_data,batch_size = batch_size,shuffle = True)
test_loader = DataLoader(dataset = test_data,batch_size = batch_size,shuffle = True)
print(len(train_data), len(train_loader))
print(len(valid_data), len(valid_loader))

模型建立:

model = ViT(image_size = 224,patch_size = 16,num_classes = 2,dim = 768,depth = 12,heads = 12,mlp_dim = 3072,dropout = 0.1,emb_dropout = 0.1
).to(device) # 将Transformer模型移动到指定设备上,比如GPUmodel.load_state_dict(torch.load('vit_base_patch16_224_r.pth'),strict = False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1487936.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

数仓架构解析(第45天)

系列文章目录 经典数仓架构传统离线大数据架构 文章目录 系列文章目录烂橙子-终生成长社群群主,前言1. 经典数仓架构2. 传统离线大数据架构 烂橙子-终生成长社群群主, 采取邀约模式,不支持付费进入。 前言 经典数仓架构 传统离线大数据架…

2000-2023年上市公司融资约束指数FC指数(含原始数据+计算结果)

2000-2023年上市公司融资约束指数FC指数(含原始数据计算结果) 1、时间:2000-2023年 2、来源:上市公司年报 3、指标:证券代码、证券简称、统计截止日期、是否剔除ST或*ST或PT股、是否剔除上市不满一年、已经退市或被…

Linus: vim编辑器的使用,快捷键及配置等周边知识详解

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 vim的安装创建新用户 adduser 用户名Linus是个多用户的操作系统是否有创建用户的权限查看当前用户身份:whoami** 怎么创建设置密码passwdsudo提权(sudo输入的是用户…

前端网页打开PC端本地的应用程序实现方案

最近开发有一个需求,网页端有个入口需要跳转三维大屏,而这个大屏是一个exe应用程序。产品需要点击这个入口,并打开这个应用程序。这个就类似于百度网盘网页跳转到PC端应用程序中。 这里我们采用添加自定义协议的方式打开该应用程序。一开始可…

前端:Vue学习 - 购物车项目

前端:Vue学习 - 购物车项目 1. json-server,生成后端接口2. 购物车项目 - 实现效果3. 参考代码 - Vuex 1. json-server,生成后端接口 全局安装json-server,json-server官网为:json-server npm install json-server -…

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配&#xff01;有了上次的内容铺垫&#xff0c;我们可以根据用户的token来判定&#xff0c;到底是显示什么内容了。 1&#xff1a;我们在对应的导航组件内修改完善一下内容即可。 <script setup> import { useUserSt…

深入理解TensorFlow底层架构

目录 深入理解TensorFlow底层架构 一、概述 二、TensorFlow核心概念 计算图 张量 三、TensorFlow架构组件 前端 后端 四、分布式计算 集群管理 并行计算 五、性能优化 内存管理 XLA编译 六、总结与展望 深入理解TensorFlow底层架构 一、概述 TensorFlow是一个开…

嵌入式C++、InfluxDB、Spark、MQTT协议、和Dash:树莓派集群物联网数据中心设计与实现(代码示例)

1. 项目概述 随着物联网技术的快速发展,如何高效地收集、存储和分析海量IoT设备数据成为一个重要课题。本文介绍了一个基于树莓派集群搭建的小型物联网数据中心,实现了从数据采集到分析可视化的完整流程。 该系统采用轻量级组件,适合资源受限的边缘计算环境。主要功能包括: …

分类常用的评价指标-二分类/多分类

二分类常用的性能度量指标 精确率、召回率、F1、TPR、FPR、AUC、PR曲线、ROC曲线、混淆矩阵 「精确率」查准率 PrecisionTP/(TPFP) 「召回率」查全率RecallTP/(TPFN) 「真正例率」即为正例被判断为正例的概率TPRTP/(TPFN) 「假正例率」即为反例被判断为正例的概率FPRFP/(TNFP)…

Java代码基础算法练习-数值求和-2024.07.25

任务描述&#xff1a; 现有一串字符(长度不超过255个字符)&#xff0c;需对其中的数值字符求和&#xff08;需转换成整型进行计算&#xff09;。 解决思路&#xff1a; 输入字符串&#xff0c;循环对每个字符否为数字&#xff0c;转换整型并求和 转换整型有以下的方式 1. su…

当 Nginx 出现请求的重复提交,如何处理?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; 文章目录 当 Nginx 出现请求的重复提交&#xff0c;如何处理&#xff1f;一、理解请求重复提交的来龙去脉二、请求重复提交可能带来的麻烦三、解决方案之“一夫当关”——…

文件包涵条件竞争(ctfshow82)

Web82 利用 session.upload_progress 包含文件漏洞 <!DOCTYPE html> <html> <body> <form action"https://09558c1b-9569-4abd-bf78-86c4a6cb6608.challenge.ctf.show//" method"POST" enctype"multipart/form-data"> …

【YashanDB知识库】yasdb jdbc驱动集成BeetISQL中间件,业务(java)报autoAssignKey failure异常

问题现象 BeetISQL中间件版本&#xff1a;2.13.8.RELEASE 客户在调用BeetISQL提供的api向yashandb的表中执行batch insert并将返回sequence设置到传入的java bean时&#xff0c;报如下异常&#xff1a; 问题的风险及影响 影响业务流程正常执行&#xff0c;无法获得batch ins…

【BUG】已解决:IndexError: positional indexers are out-of-bounds

IndexError: positional indexers are out-of-bounds 目录 IndexError: positional indexers are out-of-bounds 【常见模块错误】 【解决方案】 原因分析 解决方法 示例代码 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博…

HarmonyOS入门-状态管理

View(UI)&#xff1a;UI渲染&#xff0c;指将build方法内的UI描述和Builder装饰的方法内的UI描述映射到界面。 State&#xff1a;状态&#xff0c;指驱动UI更新的数据。用户通过触发组件的事件方法&#xff0c;改变状态数据。状态数据的改变&#xff0c;引起UI的重新渲染。 装…

2024权益商城系统网站源码

2024权益商城系统源码&#xff0c;支持多种支付方式&#xff0c;后台商品管理&#xff0c;订单管理&#xff0c;串货管理&#xff0c;分站管理&#xff0c;会员列表&#xff0c;分销日志&#xff0c;应用配置。 上传到服务器&#xff0c;修改数据库信息&#xff0c;导入数据库…

四、GD32 MCU 常见外设介绍 (7) 7.I2C 模块介绍

7.1.I2C 基础知识 I2C(Inter-Integrated Circuit)总线是一种由Philips公司开发的两线式串行总线&#xff0c;用于内部IC控制的具有多端控制能力的双线双向串行数据总线系统&#xff0c;能够用于替代标准的并行总线&#xff0c;连接各种集成 电路和功能模块。I2C器件能够减少电…

deepin深度操作系统安装教程(完整安装步骤·详细图文教程)

官方下载教程 一、概述 如果您首次使用deepin ISO镜像文件来安装deepin系统&#xff0c;无论您之前是否有安装过Windows电脑系统或者Debian、Ubuntu等其他Linux发行版桌面操作系统&#xff0c;我们都建议您先阅读本文档再安装。安装时&#xff0c;您可以选择只安装deepin系统…

Angular由一个bug说起之八:实践中遇到的一个数据颗粒度的问题

互联网产品离不开数据处理&#xff0c;数据处理有一些基本的原则包括&#xff1a;准确性、‌完整性、‌一致性、‌保密性、‌及时性。‌ 准确性&#xff1a;是数据处理的首要目标&#xff0c;‌确保数据的真实性和可靠性。‌准确的数据是进行分析和决策的基础&#xff0c;‌因此…

思维(交互题),CF 1990E2 - Catch the Mole(Hard Version)

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 E2 - Catch the Mole(Hard Version) 二、解题报告 1、思路分析 考虑每次误判都会让鼹鼠上升一层&#xff0c;相应的&#xff0c;最外层的一层结点都没用了 由于数据范围为5000&#xff0c;我们随便找个叶子…