基于yolov5的ignore classes训练

本文提到的忽略类别和检测中的忽略类别不一样,前者是在训练中加入忽略类,后者是在检测中仅检测想要的类。

ignore class的定义

我们在标注数据集的时候都是标注的正样本,训练过程中也是这样训练,让网络对正样本计算loss。但我们也遇到过这样的目标,这个目标即不属于正样本,也不属于负样本,比如正样本是person,那么人形雕塑或者人的影子,这类物体他并不是正样本,但如果直接归为负样本也是不严谨的,因此就可以将这类物体标注为“忽略类”,这类物体在标注的时候是有标签和box信息的,比如给这类物体的标签是"-2"。又或者说在训练中,希望对某些像素大小目标设置为ignore,比如训练中忽略20x20的目标

ignore训练

ignore训练并不是说把该类目标丢弃,如果你是把该类目标从label中删除,那么不就相当于把这个目标作为负样本进行训练了嘛,这是不对的。

但如果将ignore作为一个正常的类进行训练,参与loss,那么相当于将该类作为正样本进行训练了。因此这样也是不对的。我们需要时刻记得ignore class既不是正样本,也不是负样本,是一种夹在两者之间的样本

那么问题就是如何正确对待ignore class

对于ignore class的训练,其实是希望不对其进行反向传播,仅前向传播即可。这是因为不进行反向传播就不用当正样本进行学习,而且也不会因为直接丢弃而变成负样本。

想法很简单,但做起来还是有难度的。首先第一个问题就是如果做ignore class的样本匹配。我最初的想法是在做样本匹配的时候将忽略类样本与正样本分开,比如下面的代码,通过对class_id进行类别的筛选。【在yolov5的utils/loss.py中的build_targets函数中修改】

这里的t和t_ign shape含义是一样的,【image_id,class_id,x,y,w,h,anchor_id】

t = targets * gain  # 将Box缩放到对应的特征层上
t_ign = t[t[..., 1] < 0].view(t.shape[0], -1, t.shape[2])  # 忽略类
t = t[t[..., 1] >= 0].view(t.shape[0], -1, t.shape[2])  # 正样本
image_id = torch.as_tensor(t_ign[..., 0])  # [3, num_classes] 记录每个anchor对应的image id

然后按原yolov5的方法对这两个部分分别进行anchor的宽高匹配,然后仅把正样本的结果送入后面的三个loss计算。但这是有问题的,因为在yolov5中anchor是通过宽高匹配,虽然可以分别对两个部分进行匹配,但如何判断忽略类的匹配到的anchor也会被正样本匹配到呢?因为后面需要进行一个去重工作,就是将忽略类与正样本匹配到的相同anchor去除掉

第二个问题就是,针对ignore class的训练,loss部分怎么处理?我最初想这个问题的时候,就仅仅是觉得不需要对三个loss进行处理,这显然不对,如果不对loss进行处理,那前面只做样本匹配其实是没有意义的。

ignore class样本匹配

针对第一个问题

为了可以获得更多的anchor框与ignore class的匹配,以减小对于正样本的影响,我们需要做的是在当前特征图上生成网格,并将特征图所有网格中的anchor box与ignore class进行匹配,你可以用iou,但我这里用的是ioa匹配。

在特征层上生成网格并生成anchor box代码如下:

                    # 生成网格,anchor中心grid_y, grid_x = torch.meshgrid([torch.arange(p[i].shape[3],device=targets.device), torch.arange(p[i].shape[2], device=targets.device)])grid = torch.stack((grid_x, grid_y), 2).view((1, 1, p[i].shape[3], p[i].shape[2], 2)).float()# plot_anchor(grid_x, grid_y, gain)# 在网格上生成anchoranchors_boxes = torch.zeros(3, p[i].shape[3], p[i].shape[2], 4)  # [3,80,80,4]anchors_boxes[..., :2] = grid[..., :2]  # 给所有anchor分配中心点anchors_boxes[0, :, :, 2:] = anchors[0, :]anchors_boxes[1, :, :, 2:] = anchors[1, :]anchors_boxes[2, :, :, 2:] = anchors[2, :]  # 将所有anchor的w,h传入# 此刻的anchors_boxes为所有batch下三种anchor对应的x,y,w,h shape[batch_size,3,feat_w,feat_h,4]# xywh->x1y1x2y2an_box = torch.zeros_like(anchors_boxes)  # [3,80,80,4]an_box[..., :2] = anchors_boxes[..., :2] - anchors_boxes[..., 2:] / 2an_box[..., 2:] = anchors_boxes[..., :2] + anchors_boxes[..., 2:] / 2

获得t_ign的所有box,用于后面的anchor匹配,获得box代码:

                    # 获得t_ign的box框center_xy = t_ign[..., 2:4]  # 取gt box的中心点gt_wh = t_ign[..., 4:6]  # 取gt的w和ht_ign_box = torch.zeros(3, t_ign.shape[1], 4)  # 用于存储框shape[3,num_obj,4]t_ign_box[..., :2] = center_xy - gt_wh / 2  # 左上角t_ign_box[..., 2:] = center_xy + gt_wh / 2  # 右下角

 然后是遍历当前head中所有cell中的anchor box与忽略类的box进行ioa匹配。image_id是之前定义的用于记录anchor对应的目标image id[或者说是batch id]。因为在前面我们已经给t_ign中的每个目标均分配了三种anchor,而且给每个目标记录了所在的image id【t_ign[...,0]就是image id,最后一个维度是anchor id了】。anchor_id用于记录匹配到的anchor id【每个head上设置了三种anchor】,ign_gi和ign_gj是记录满足ioa阈值的anchor的中心点坐标。

                    anchor_id = []ign_gi = []ign_gj = []image_id_list = []#plot_ign_cls_box(t_ign_box, gain)for an_idx in range(an_box.shape[0]):  # 遍历3种anchorfor i in range(an_box.shape[1]):  # 遍历每个网格获得每个anchor的box 行遍历for j in range(an_box.shape[2]):  # 列遍历# 将所有目标box与grid中的所有anchor计算ioa,输入shape[num_obj],得到每个目标ioaioa = compute_ioa(box1=t_ign_box[an_idx], box2=an_box[an_idx, i, j])mask = ioa > ioa_threif mask.shape[0] and torch.max(mask):  # 所有batch中表示匹配到image_id_list.append(image_id[an_idx][mask])anchor_id.append(an_idx)ign_gi.append(j)  # x坐标ign_gj.append(i)  # y坐标

最后将上面几个列表存储在ign_match_anchor列表中。shape为【batch,anchor,x,y】。

ign_match_anchor.append((torch.IntTensor([val[0] for val in image_id_list]),torch.IntTensor(anchor_id),torch.IntTensor(ign_gj),torch.IntTensor(ign_gi)))

通过设置的ioa阈值的不同,匹配到的anchor数量也不同。下图分别是阈值为0.5和0.2时在80x80特征图上匹配示意图,绿色框为ignore class的gt,红色框为anchor box。阈值为>0.2的有更多的ign_class被anchor匹配到。

 

ioa阈值>0.5
ioa阈值>0.2

 

loss设计

代码中的indices记录了正样本的[image_id,anchor_id,x,y],ign_anchors则记录了与忽略类匹配到的【image_id,anchor_id,anchor_x,anchor_y】.

        if self.ignore_class:tcls, tbox, indices, anchors, ign_anchors = self.build_targets(p, targets)  # targetsfor i, pi in enumerate(p):  # layer index, layer predictions  pi shape [batch_size,3,feat_w,feat_h,5+classes]b, a, gj, gi = indices[i]  # image, anchor, gridy, gridxif self.ignore_class:b_ign, a_ign, gj_ign, gi_ign = ign_anchors[i]b_ign = b_ign.type_as(b)a_ign = a_ign.type_as(a)gj_ign = gj_ign.type_as(gj)gi_ign = gi_ign.type_as(gi)n_ign = b_ign.shape[0]tobj_ign = torch.zeros_like(pi[..., 0], device=device)tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

 在yolo中的三个损失函数中,box_loss,cls_loss其实对于忽略类可以不用,而obj_loss怎么算呢?因为通过前面的操作其实可以知道和忽略类匹配到的anchor框,然后我们需要在正样本以及预测pred中去重,也就是看正样本obj和pred哪些地方和忽略类anchor是重叠的,去除点,然后将这个像素点作为负样本【也就是将重复的坐标置0即可,这样该点就不反向传播了】。

                if self.ignore_class:tobj_ign[b_ign, a_ign, gj_ign, gi_ign] = 1# 去重mask = (tobj != 0) & (tobj_ign != 0)  # 相同位置不为0的地方tobj[mask] = 0  # 正样本与ign相同位置不为0的地方置0【正样本去重成功】pi[..., 4][mask] = 0  # 预测中正样本与ign相同位置不为0的地方置0【预测类正样本去重成功】obji = self.BCEobj(pi[..., 4], tobj)lobj += obji * self.balance[i]  # obj loss    

通过以上的样本匹配和loss设计就完成了忽略训练的设计。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/143757.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

五、C#—字符串

&#x1f33b;&#x1f33b; 目录 一、字符串1.1 字符类型1.2 转义字符1.3 字符串的声明及赋值1.3.1 c# 中的字符串1.3.2 声明字符串1.3.3 使用字符串1.3.4 字符串的初始化1.3.4.1 引用字符串常量之初始化1.3.4.2 利用字符数组初始化1.3.4.3 提取数组中的一部分进行初始化 1.3.…

R的一些奇奇怪怪的功能

1. 欧氏距离计算 df <- data.frame(x 1:10, y 1:10, row.names paste0("s", 1:10)) euro_dist <- as.matrix(dist(df))2. 集合运算 union(x, y) # 并集 intersect(x, y) # 交集 setdiff(x, y) # 只在x中存在&#xff0c;y中不存在的元素 setequal(x, y)…

利用Redis实现全局唯一ID

利用Redis实现全局唯一ID 背景 场景分析&#xff1a;如果我们的id具有太明显的规则&#xff0c;用户或者说商业对手很容易猜测出来我们的一些敏感信息&#xff0c;比如商城在一天时间内&#xff0c;卖出了多少单&#xff0c;这明显不合适。 场景分析二&#xff1a;随着我们商…

慢性疼痛治疗服务公司Kindly MD申请700万美元纳斯达克IPO上市

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉,慢性疼痛治疗服务公司Kindly MD近期已向美国证券交易委员会&#xff08;SEC&#xff09;提交招股书&#xff0c;申请在纳斯达克IPO上市&#xff0c;股票代码为&#xff08;KDLY&#xff09;,Kindly MD计划通过…

PTA程序辅助实验平台——2023年软件设计综合实践_3(分支与循环)

第一题&#xff1a;7-1 印第安男孩 - C/C 分支与循环 朵拉编程的时候也想顺便练习英语。她编程从键盘读入一个整数n&#xff0c;如果n值为0或者1&#xff0c;向屏幕输出“0 indian boy.”或“1 indian boy.”&#xff1b;如果n大于1&#xff0c;比如9&#xff0c;则输出“9 in…

计算机图像处理:图像轮廓

图像轮廓 图像阈值分割主要是针对图片的背景和前景进行分离&#xff0c;而图像轮廓也是图像中非常重要的一个特征信息&#xff0c;通过对图像轮廓的操作&#xff0c;就能获取目标图像的大小、位置、方向等信息。画出图像轮廓的基本思路是&#xff1a;先用阈值分割划分为两类图…

性能测试 —— 性能测试常见的测试指标 !

一、什么是性能测试 先看下百度百科对它的定义&#xff0c;性能测试是通过自动化的测试工具模拟多种正常、峰值以及异常负载条件来对系统的各项性能指标进行测试。 我们可以认为性能测试是&#xff1a;通过在测试环境下对系统或构件的性能进行探测&#xff0c;用以验证在生产环…

mysql面试题3:谈谈你知道的MySQL索引?MySQL中一个表可以创建多少个列索引?MySQL索引有哪几种?他们的优缺点是什么?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:谈谈你知道的MySQL索引? MySQL索引是一种特殊的数据结构,用于加速数据库的查询操作。它通过存储列值和对应记录的指针,可以快速定位到满足查询…

如果只是用php纯做api的话,给移动端做数据接口,是否需要用php框架?

API接口对接是现代软件开发中不可或缺的一部分&#xff0c;它允许不同的应用程序之间进行数据交换和服务调用。在PHP中&#xff0c;可以使用多种方式实现API接口的对接&#xff0c;包括基于HTTP协议的传统方法以及现代的API客户端库客户端库客户端库等。 一、实现API接口的对接…

【React】组件实例三大属性state、props、refs

state React 把组件看成是一个状态机&#xff08;State Machines&#xff09;。通过与用户的交互&#xff0c;实现不同状态&#xff0c;然后渲染 UI&#xff0c;让用户界面和数据保持一致。 React 里&#xff0c;只需更新组件的 state&#xff0c;然后根据新的 state 重新渲染用…

运行在浏览器中的Domino Designer开发客户机

大家好&#xff0c;才是真的好。 首先讨论一个非常有意思的事情&#xff0c;就是有人问&#xff0c;如果我用很老的Lotus软件&#xff0c;它是免费的吗&#xff1f; 这估计代表了很多盆友的心声。但不太友好的是&#xff0c;即使你用很老的Lotus软件&#xff08;例如Notes R4…

百度搜索逐步恢复优质网站权限

我是卢松松&#xff0c;点点上面的头像&#xff0c;欢迎关注我哦&#xff01; 从9月25日开始&#xff0c;有越来越多的站长和卢松松反馈&#xff0c;说他们的站可以正常入驻百度搜索资源平台了。我也试了试卢松松博客&#xff0c;果然&#xff0c;可以正常提交了。还是以前的…

Redis 线程模式

Redis 是单线程吗&#xff1f; Redis 单线程指的是 [接收客户端请求 -> 解析请求 -> 进行数据读写操作 -> 发送数据给客户端] 这个过程是由一个线程 (主线程) 来完成的&#xff0c;这也是常说的 Redis 是单线程的原因。 但是 &#xff0c;Redis 程序不是单线程的&am…

已实现:关于富文本组件库vue2-editor的使用方法以及一些必要的注意事项,特别是设置完富文本以后的回显问题。以及在光标位置插入字符串等

前言 目前常见的基于vue的富文本编辑器有两个&#xff1a;“vue2-editor” 和 “vue-quill-editor” 都是用于Vue.js的富文本编辑器组件&#xff0c;它们具有一些共同的特点&#xff0c;但也有一些区别。 共同点&#xff1a; 1、富文本编辑功能&#xff1a; 两者都提供了富文…

Ubuntu安装Oracle JDK

文章目录 下载JDK安装Oracle JDK验证安装 下载JDK Oracle JDK需要从Oracle的官方网站下载&#xff0c;访问Oracle的官方网站并下载所需版本的JDK。 https://www.oracle.com/java/technologies/downloads/#java17 安装Oracle JDK 2.1. 下载.tar.gz文件后&#xff0c;移动到适…

el-tooltip内容换行显示

效果图&#xff1a; html: <div class"rules-tooltip flex-center"><el-tooltip class"item" effect"dark" placement"bottom-start"><div slot"content" v-html"tipsContent"></div>&l…

二维平面扭曲的python实现及思路

二维平面扭曲的python实现及思路 缘起原理实现代码 缘起 工作需要&#xff0c;需要一个尝试改变设备布点的方法&#xff0c;在csdn闲逛时&#xff0c;偶然间发现这样的一篇文章 二维扭曲&#xff0c;参考这位博主的文章&#xff0c;我对其内容进行复现和进一步挖掘。若有侵权或…

多维时序 | MATLAB实现WOA-CNN-LSTM-Attention多变量时间序列预测(SE注意力机制)

多维时序 | MATLAB实现WOA-CNN-LSTM-Attention多变量时间序列预测&#xff08;SE注意力机制&#xff09; 目录 多维时序 | MATLAB实现WOA-CNN-LSTM-Attention多变量时间序列预测&#xff08;SE注意力机制&#xff09;预测效果基本描述模型描述程序设计参考资料 预测效果 基本描…

肖sir__mysql之综合题练习__013

数据库题&#xff08;10*5&#xff09; 下面是一个学生与课程的数据库&#xff0c;三个关系表为&#xff1a; 学生表S&#xff08;Sid&#xff0c;SNAME,AGE,SEX&#xff09; 成绩表SC&#xff08;Sid&#xff0c;Cid&#xff0c;GRADE&#xff09; 课程表C&#xff08;Cid&…