【PyTorch】图像目标检测

图像目标检测是什么

Object Detection
判断图像中目标位置

目标检测两要素

  1. 分类:分类向量 [p0, …, pn]
  2. 回归:回归边界框 [x1, y1, x2, y2]

模型如何完成目标检测

将3D张量映射到两个张量

  1. 分类张量:shape为 [N, c+1]
  2. 边界框张量:shape为 [N, 4]

Recent Advances in Deep Learning for Object Detection
在这里插入图片描述

边界框数量N如何确定?

传统方法——滑动窗策略

缺点:

  1. 重复计算量大
  2. 窗口大小难确定

利用卷积减少重复计算
重要概念:特征图一个像素对应原图一块区域

目标检测模型简介

按流程分为:one-stage和two-stage
在这里插入图片描述
在这里插入图片描述

Faster RCNN

经典two stage检测网络

Faster RCNN 数据流

  1. Feature map:[256, h_f, w_f]
  2. 2 Softmax:[num_anchors, h_f, w_f]
  3. Regressors:[num_anchors*4, h_f, w_f]
  4. NMS OUT:[n_proposals=2000, 4]
  5. ROI Layer:[512, 256, 7, 7]
  6. FC1 FC2:[512, 1024]
  7. c+1 Softmax: [512, c+1]
  8. Regressors:[512, (c+1)*4]

Faster RCNN 主要组件

  1. backbone
  2. rpn
  3. filter_proposals(NMS)
  4. roi_heads
    在这里插入图片描述

Faster RCNN 行人检测

数据: PennFudanPed数据集, 70张行人照片共345行人标签
官方地址: http://www.cis.upenn.edu/~jshi/ped_html/
模型: fasterrcnn_resnet50_fpn 进行finetune
目标检测推荐github:https://github.com/amusi/awesome-object-detection

代码如下:

import os
import time
import torch.nn as nn
import torch
import random
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
import torch.nn.functional as F
from my_dataset import PennFudanDataset
from common_tools import set_seed
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
import enviromentsset_seed(1)  # 设置随机种子BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign','parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow','elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A','handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball','kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket','bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl','banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza','donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table','N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone','microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book','clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]def vis_bbox(img, output, classes, max_vis=40, prob_thres=0.4):fig, ax = plt.subplots(figsize=(12, 12))ax.imshow(img, aspect='equal')out_boxes = output_dict["boxes"].cpu()out_scores = output_dict["scores"].cpu()out_labels = output_dict["labels"].cpu()num_boxes = out_boxes.shape[0]for idx in range(0, min(num_boxes, max_vis)):score = out_scores[idx].numpy()bbox = out_boxes[idx].numpy()class_name = classes[out_labels[idx]]if score < prob_thres:continueax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')plt.show()plt.close()class Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, image, target):for t in self.transforms:image, target = t(image, target)return image, targetclass RandomHorizontalFlip(object):def __init__(self, prob):self.prob = probdef __call__(self, image, target):if random.random() < self.prob:height, width = image.shape[-2:]image = image.flip(-1)bbox = target["boxes"]bbox[:, [0, 2]] = width - bbox[:, [2, 0]]target["boxes"] = bboxreturn image, targetclass ToTensor(object):def __call__(self, image, target):image = F.to_tensor(image)return image, targetif __name__ == "__main__":# configLR = 0.001num_classes = 2batch_size = 1start_epoch, max_epoch = 0, 5train_dir = enviroments.pennFudanPed_data_dirtrain_transform = Compose([ToTensor(), RandomHorizontalFlip(0.5)])# step 1: datatrain_set = PennFudanDataset(data_dir=train_dir, transforms=train_transform)# 收集batch data的函数def collate_fn(batch):return tuple(zip(*batch))train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_fn)# step 2: modelmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)in_features = model.roi_heads.box_predictor.cls_score.in_featuresmodel.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) # replace the pre-trained head with a new onemodel.to(device)# step 3: loss# in lib/python3.6/site-packages/torchvision/models/detection/roi_heads.py# def fastrcnn_loss(class_logits, box_regression, labels, regression_targets)# step 4: optimizer schedulerparams = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=LR, momentum=0.9, weight_decay=0.0005)lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# step 5: Iterationfor epoch in range(start_epoch, max_epoch):model.train()for iter, (images, targets) in enumerate(train_loader):images = list(image.to(device) for image in images)targets = [{k: v.to(device) for k, v in t.items()} for t in targets]# if torch.cuda.is_available():#     images, targets = images.to(device), targets.to(device)loss_dict = model(images, targets)  # images is list; targets is [ dict["boxes":**, "labels":**], dict[] ]losses = sum(loss for loss in loss_dict.values())print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} ".format(epoch, max_epoch, iter + 1, len(train_loader), losses.item()))optimizer.zero_grad()losses.backward()optimizer.step()lr_scheduler.step()# testmodel.eval()# configvis_num = 5vis_dir = os.path.join(BASE_DIR, "..", "..", "data", "PennFudanPed", "PNGImages")img_names = list(filter(lambda x: x.endswith(".png"), os.listdir(vis_dir)))random.shuffle(img_names)preprocess = transforms.Compose([transforms.ToTensor(), ])for i in range(0, vis_num):path_img = os.path.join(vis_dir, img_names[i])# preprocessinput_image = Image.open(path_img).convert("RGB")img_chw = preprocess(input_image)# to deviceif torch.cuda.is_available():img_chw = img_chw.to('cuda')model.to('cuda')# forwardinput_list = [img_chw]with torch.no_grad():tic = time.time()print("input img tensor shape:{}".format(input_list[0].shape))output_list = model(input_list)output_dict = output_list[0]print("pass: {:.3f}s".format(time.time() - tic))# visualizationvis_bbox(input_image, output_dict, COCO_INSTANCE_CATEGORY_NAMES, max_vis=20, prob_thres=0.5)  # for 2 epoch for nms

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1551877.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

索尼MDR-M1:超宽频的音频盛宴,打造沉浸式音乐体验

在音乐的世界里&#xff0c;每一次技术的突破都意味着全新的听觉体验。 索尼&#xff0c;作为音频技术的先锋&#xff0c;再次以其最新力作——MDR-M1封闭式监听耳机&#xff0c;引领了音乐界的新潮流。 这款耳机以其超宽频播放和卓越的隔音性能&#xff0c;为音乐爱好者和专…

深蕾半导体参加2024年度上海设计100+全球竞赛展览WDCC

展览介绍 WDCC2024 上海于2010年加入联合国教科文组织“创意城市网络”&#xff0c;定名为“设计之都”。“上海设计100”全球竞赛&#xff0c;遴选推广优秀设计案例&#xff0c;将“设计之都”的规划和愿景具体呈现。 ——展出时间、地点见文末—— 深蕾参展 深圳前海深蕾…

初识Linux · 进程等待

目录 前言&#xff1a; 进程等待是什么 为什么需要进程等待 进程等待都在做什么 前言&#xff1a; 通过上文的学习&#xff0c;我们了解了进程终止&#xff0c;知道终止是在干什么&#xff0c;终止的三种情况&#xff0c;以及有了退出码&#xff0c;错误码的概念&#xff…

Python | Leetcode Python题解之第448题找到所有数组中消失的数字

题目&#xff1a; 题解&#xff1a; class Solution:def findDisappearedNumbers(self, nums: List[int]) -> List[int]:n len(nums)for num in nums:x (num - 1) % nnums[x] nret [i 1 for i, num in enumerate(nums) if num < n]return ret

【RocketMQ】秒杀设计与实现

&#x1f3af; 导读&#xff1a;本文档详细探讨了高并发场景下的秒杀系统设计与优化策略&#xff0c;特别是如何在短时间内处理大量请求。文档分析了系统性能指标如QPS&#xff08;每秒查询率&#xff09;和TPS&#xff08;每秒事务数&#xff09;&#xff0c;并通过实例讲解了…

鸿蒙开发(NEXT/API 12)【申请接入Wear Engine服务】 穿戴服务

申请Wear Engine服务前&#xff08;开发者需实名认证为个人开发者或者企业开发者&#xff0c;认证前&#xff0c;请先了解二者的[权益区别] &#xff09;&#xff0c;确认开发环境并完成创建项目、创建HarmonyOS应用等基本准备工作&#xff0c;再继续进行以下开发活动。 进入华…

JVM(HotSpot):字符串常量池(StringTable)

文章目录 一、内存结构图二、案例讲解三、总结 一、内存结构图 JDK1.6 JDK1.8 我们发现&#xff0c;StringTable移入了Heap里面。所以&#xff0c;应该想到&#xff0c;StringTable将受到GC管理。 其实&#xff0c;1.6中&#xff0c;在方法区中的时候&#xff0c;也是受GC管…

Android Studio 新版本 Logcat 的使用详解

点击进入官方Logcat介绍 一个好的Android程序员要会使用AndroidStudio自带的Logcat查看日志&#xff0c;会Log定位也是查找程序bug的第一关键。同时Logcat是一个查看和处理日志消息的工具&#xff0c;它可以更快的帮助开发者调试应用程序。 步入正题&#xff0c;看图说话。 点…

Linux 之 IO模型

IO的本质是基于操作系统接口来控制底层的硬件之间数据传输&#xff0c;并且在操作系统中实现了多种不同的IO方式&#xff08;模型&#xff09;&#xff0c;比较常见的有下列三种 阻塞型IO模型 非阻塞型IO模型 多路复用IO模型 一、阻塞与非阻塞IO 一般默认的 IO 操作都是阻塞…

在Linux中进行OpenSSH升级(编译安装在openssh目录)

由于OpenSSH有严重漏洞&#xff0c;因此需要升级OpenSSH到最新版本。 注意&#xff1a;在OpenSSH升级过程中千万不要断开服务器连接&#xff0c;不然的话&#xff0c;会出现断开后连接不了服务器的情况。 第一步&#xff0c;查看当前的OpenSSH服务版本。 命令&#xff1a;ss…

DataEase v2 开源代码 Windows 从0到1环境搭建

一、环境准备 功能名称 描述 其它 操作系统 Windows 数据库 Mysql8.0 开发环境 JDK17以上 本项基于的21版本开发 Maven 3.9版本 开发工具 idea2024.2版本 前端 VSCode TIPS&#xff1a;如果你本地有jdk8版本&#xff0c;需要切换21版本&#xff0c;请看…

C语言 | Leetcode C语言题解之第448题找到所有数组中消失的数字

题目&#xff1a; 题解&#xff1a; int* findDisappearedNumbers(int* nums, int numsSize, int* returnSize) {for (int i 0; i < numsSize; i) {int x (nums[i] - 1) % numsSize;nums[x] numsSize;}int* ret malloc(sizeof(int) * numsSize);*returnSize 0;for (in…

遥感图像文本检索

遥感图像文本检索是一种通过自然语言描述&#xff0c;从大量遥感图像中搜索与之相关的图像的技术。它用于遥感解释任务中&#xff0c;帮助用户根据文字描述快速找到符合条件的遥感图像&#xff0c;这在城市规划、环境监测、灾害管理等领域具有重要应用意义。 实现这一技术的核…

线路交换与分组交换的深度解析

1. 线路交换 原理 线路交换是一种在通信双方之间建立固定通信路径的方式。当用户发起通信时&#xff0c;网络为其分配一条专用的物理通道&#xff0c;这条通道在整个通话过程中保持不变。这意味着在通话期间&#xff0c;其他用户无法使用这条线路。 优点 稳定性&#xff1a…

记录一次出现循环依赖问题

具体的结构设计&#xff1a; 在上面的图片中&#xff1a; UnboundBlackVerifyChain类中继承了UnboundChain类。但是UnboundChain类中注入了下面三个类。 Scope(“prototype”) UnboundLinkFlowCheck类 Scope(“prototype”) UnboundUserNameCheck类 Scope(“prototype”) Un…

【刷题6】一维前缀和、二维前缀和

目录 一、一维前缀和二、二维前缀和 一、一维前缀和 题目&#xff1a; 思路&#xff1a; 一、前缀和&#xff0c;时间复杂度O&#xff08;1&#xff09;&#xff0c;快速得到区间的值 二、预处理&#xff0c;公式——dp[i] dp[i-1] arr[i] 三、使用前缀和&#xff0c;根据…

VUE a-table 动态拖动修改列宽+固定列

实现效果 实现思路 自定义表头&#xff0c;在标题后面加两个标签&#xff0c;分别用来显示拖拽图标&#xff08;cursor: col-resize&#xff09;&#xff0c;和蓝色标记线&#xff08;有的时候鼠标移动过程中不一定会在表内&#xff0c;这个时候不显示图标&#xff0c;只显示蓝…

综合练习 学习案例

//验证码 前四位是字母 最后一位是数字 public class test1 {public static void main(String[] args){char [] charsnew char[52];for (int i 0; i <chars.length ; i) {if(i<25){chars[i](char)(i97);}else{chars[i](char)(i65-26);}}Random rnew Random();String cod…

828华为云征文|华为云Flexus云服务器X实例部署 即时通讯IM聊天交友软件——高性能服务器实现120W并发连接

营运版的即时通讯IM聊天交友系统&#xff1a;特点可发红包&#xff0c;可添加多条链接到用户网站和应用&#xff0c;安卓苹果APPPC端H5四合一 后端开发语言&#xff1a;PHP&#xff0c; 前端开发语言&#xff1a;uniapp混合开发。 集安卓苹果APPPC端H5四合一APP源码&#xff0…

语音转文字免费利器:助力高效办公与学习

语音转文字免费的软件如同一股清流&#xff0c;让我们能够更轻松地将语音信息转化为可编辑的文字内容。今天我们一起来分析它们的功能、特点以及如何为我们的生活和工作带来便利。 1.365在线转文字 链接直达&#xff1a;https://www.pdf365.cn/ 这是一个功能强大的在线工具…