【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeatclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = img_size // patch_sizeself.num_patches = self.grid_size ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.proj(x)  # [B, embed_dim, H, W]x = x.flatten(2)  # [B, embed_dim, num_patches]x = x.transpose(1, 2)  # [B, num_patches, embed_dim]return xclass Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)self.norm2 = nn.LayerNorm(dim)self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])for i in range(depth)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=0.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)cls_token_final = x[:, 0]x = self.head(cls_token_final)return x# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 数据转换和增强
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):model.train()running_loss = 0.0running_corrects = 0for inputs, labels in tqdm(data_loader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_accdef evaluate(model, criterion, data_loader, device):model.eval()running_loss = 0.0running_corrects = 0with torch.no_grad():for inputs, labels in data_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_acc# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25for epoch in range(num_epochs):train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)val_loss, val_acc = evaluate(model, criterion, val_loader, device)print(f'Epoch {epoch}/{num_epochs - 1}')print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1453050.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

决策树算法:揭示数据背后的决策逻辑

目录 一 决策树算法原理 特征选择 信息增益 信息增益比 基尼指数 树的构建 树的剪枝 预剪枝 后剪枝 二 决策树算法实现 一 使用决策树进行分类 数据预处理 构建决策树模型 二 使用决策树进行回归 数据预处理 构建决策树回归模型 三 决策树算法的优缺点 优点 …

政务云参考技术架构

行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…

Python编程环境搭建

简介: Python环境安装比较简单,无需安装其它依赖环境,主要步骤为: 1. 下载并安装Python对应版本解释器 2. 下载并安装一个ide编码工具 一、下载并安装Python解释器 1.1 下载 官网地址:Welcome to Python.org 选择…

2023年13个最适合销售电子书的WordPress主题

欢迎来到我们用于销售电子书和其他数字/可下载产品(软件、应用程序、图标集、主题等)的最佳WordPress主题的完整集合。 这些主题有内置的支付网关,可以通过 PayPal、信用卡等处理安全支付。(易于配置!) 最…

生产中的 RAG:使你的生成式 AI 项目投入运营

作者:来自 Elastic Tim Brophy 检索增强生成 (RAG) 为组织提供了一个采用大型语言模型 (LLM) 的机会,即通过将生成式人工智能 (GenAI) 功能应用于其自己的专有数据。使用 RAG 可以降低固有风险,因为我们依赖受控数据集作为模型答案的基础&…

四. TensorRT模型部署优化-quantization(quantization granularity)

目录 前言0. 简述1. 量化粒度2. Per-tensor & Per-channel量化3. 量化粒度选择的推荐方法总结 前言 自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》,链接。记录下个人学习笔记,仅供自己参考 本次课程我们来学习课程第四章—TensorRT 模型部署优…

[面试题]Spring Boot

[面试题]Java【基础】[面试题]Java【虚拟机】[面试题]Java【并发】[面试题]Java【集合】[面试题]MySQL[面试题]Maven[面试题]Spring Boot[面试题]Spring Cloud[面试题]Spring MVC Spring Boot 涉及到的知识点很多,在内容上,我们会分成两大块&#xff1a…

day38-39| 509. 斐波那契数 70. 爬楼梯 746. 使用最小花费爬楼梯 62.不同路径 343. 整数拆分 96.不同的二叉搜索树

文章目录 前言动态规划理论基础509. 斐波那契数思路方法一 完整动态规划方法二 dp简化版方法三 使用递归 70. 爬楼梯思路方法一 动态规划方法一2 教程里面的简化方法方法二 拓展 746. 使用最小花费爬楼梯思路方法一方法二 拓展 62.不同路径思路 动态规划方法一方法二 递归 63. …

POC EXP | woodpecker插件编写

woodpecker插件编写 目录 woodpecker介绍woodpecker使用插件编写 安装环境 woodpecker-sdkwoodpecker-request 创建Maven项目 Confluence OGNL表达式注入漏洞插件编写 创建Package包和Class类编写POC 漏洞POC代码编写导出jar包将jar包放入woodpecker的plugin目录运行woodpeck…

我主编的电子技术实验手册(07)——串联电路

本专栏是笔者主编教材(图0所示)的电子版,依托简易的元器件和仪表安排了30多个实验,主要面向经费不太充足的中高职院校。每个实验都安排了必不可少的【预习知识】,精心设计的【实验步骤】,全面丰富的【思考习…

宿舍用电管理模块一进三出的升级改造

宿舍用电管理模块一进三出石家庄光大远通电气有限公司产品在高校日常管理工作中,宿舍管理是一项重要工作。宿舍管理内容复杂,而且涉及学生的日常生活,意义重大。其中,学生宿舍内漏电,超负荷用电,违规用电等现象一直是困扰后勤管理的普遍问题。随着学生日常生活方式以及生活用品…

【Pandas驯化-04】Pandas中drop_duplicates、describe、翻转操作

【Pandas驯化-04】Pandas中drop_duplicates、describe、翻转操作 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! 🎇 相关内容文档获取 微信公…

GPTZero:引领AI内容检测

随着人工智能技术的飞速发展,AI生成内容(AIGC)正在迅速改变我们获取和消费信息的方式。然而,AIGC的激增也带来了一系列挑战,尤其是在内容真实性和版权方面。正是在这样的背景下,一家由00后团队创立的公司——GPTZero,以其独特的AI检测工具,迅速崛起为行业的领军者。 一…

【LeetCode215】数组中的第K个最大元素

题目地址 1. 基本思路 用一个基准数e将集合S分解为不包含e在内的两个小集合 S 1 S_{1} S1​和 S 2 S_{2} S2​,其中 S 1 S_{1} S1​的任何元素均大于等于e, S 2 S_{2} S2​的任何元素均小于e,记 ∣ S ∣ |S| ∣S∣代表集合S元素的个数&…

JDBC常见的几种连接池使用(C3P0、Druid、HikariCP 、DBCP)

前言 数据库连接池负责分配、管理和释放数据库连接,它允许应用程序重复使用一个现有的数据库连接,而不是重新建立一个。连接池技术尽可能多地重用了消耗内存的资源,大大节省了内存。通过使用连接池,将大大提高程序运行效率。常用的…

Java——构造器(构造方法)和 this

一、什么是构造器 构造器(Constructor)是Java类的一种特殊方法,用于初始化对象的状态。构造器在创建对象时被调用,可以对对象的成员变量进行初始化。 我之前的文章《Java——类和对象-CSDN博客》中也提到了构造器。 二、构造器…

【面试题】MySQL常见面试题总结

备战实习,会定期给大家整理常考的面试题,大家一起加油! 🎯 系列文章目录 【面试题】面试题分享之JVM篇【面试题】面试题分享之Java并发篇【面试题】面试题分享之Java集合篇(三) 注意:文章若有错…

AI办公自动化:批量在多个Word文档中插入对应图片

工作任务:文件夹中有多个word文档和word文档名称一致的图片,要把这些图片都插入到word文档中 在chatpgt中输入提示词: 你是一个Python编程专家,写一个Python脚本,具体步骤如下: 打开文件夹:F:…

蓝队-溯源技巧

溯源技巧 大致思想 通常情况下,接到溯源任务时,获得的信息如下 攻击时间 攻击 IP 预警平台 攻击类型 恶意文件 受攻击域名/IP其中攻击 IP、攻击类型、恶意文件、攻击详情是溯源入手的点。 通过攻击类型分析攻击详情的请求包,看有没有攻击者…

零基础入门学用Arduino 第三部分(二)

重要的内容写在前面: 该系列是以up主太极创客的零基础入门学用Arduino教程为基础制作的学习笔记。个人把这个教程学完之后,整体感觉是很好的,如果有条件的可以先学习一些相关课程,学起来会更加轻松,相关课程有数字电路…