AOT源码解析4.4 -decoder生成预测mask并计算loss

3、生成ref_imgs的预测mask和loss

这一步在训练阶段调用

3.1 数据处理

在这里插入图片描述

图1,如图1所示,将enc_embs的最后一个比例的特征图和有ref_imgs相关的特征图得到的LSTT特征图相拼接作为输入

        curr_enc_embs = self.curr_enc_embscurr_lstt_embs = self.curr_lstt_output[0]pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,curr_enc_embs)

3.2 Decoder结构

在这里插入图片描述

图2, decoder的操作步骤如图,该解码器将enc_embs各个比例的特征图结合到一起

  • Decoder结构
class FPNSegmentationHead(nn.Module):def __init__(self,in_dim,out_dim,decode_intermediate_input=True,hidden_dim=256,shortcut_dims=[24, 32, 96, 1280],align_corners=True):super().__init__()self.align_corners = align_cornersself.decode_intermediate_input = decode_intermediate_inputself.conv_in = ConvGN(in_dim, hidden_dim, 1)self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1)self._init_weight()def forward(self, inputs, shortcuts):if self.decode_intermediate_input:x = torch.cat(inputs, dim=1)else:x = inputs[-1]x = F.relu_(self.conv_in(x))s1 = self.adapter_16x(shortcuts[-2])x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x))x = F.interpolate(x,size=shortcuts[-3].size()[-2:],mode="bilinear",align_corners=self.align_corners)x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x))x = F.interpolate(x,size=shortcuts[-4].size()[-2:],mode="bilinear",align_corners=self.align_corners)x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x))x = self.conv_out(x)return x

3.3 计算loss

在这里插入图片描述

  • 对Decoder输出的结果按照对象数量进行分隔
        pred_id_logits = self.pred_id_logitspred_id_logits = F.interpolate(pred_id_logits,size=gt_mask.size()[-2:],mode="bilinear",align_corners=self.align_corners)label_list = []logit_list = []for batch_idx, obj_num in enumerate(self.obj_nums):now_label = gt_mask[batch_idx].long()now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)label_list.append(now_label.long())logit_list.append(now_logit)
  • 计算loss

在深度学习中,尤其是在图像相关的任务(如图像分割)中,我们通常有大量的像素需要预测。在这种情况下,可能并不是所有的像素对最终的任务都同样重要。
例如,模型可能已经能够很好地预测图像的大部分区域,但是对于一些难以区分的区域(如物体边缘或小物体)预测得不够好。这些难以预测的区域可能正是模型需要关注的重点。

为了使模型更加关注这些难以预测的区域,可以采用一种称为“硬例挖掘”(hard example mining)的技术。这种方法的基本思想是,不是对所有的像素平均地计算损失,而是只关注那些损失最大的像素。

通过这种方式,模型的训练可以更加集中在那些难以正确预测的像素上,从而提高模型的整体性能。具体来说,“top k percent pixels” 指的是按照损失值从高到低排序后,选取前 k 百分比的像素。例如,如果 k 设置为 50%,那么在损失计算中,只会考虑损失最大的前 50% 的像素。

在代码中,这通常是通过以下步骤实现的:

  • 计算所有像素的损失。
  • 根据损失值对像素进行排序。
  • 选择损失值最高的前 k 百分比的像素。
  • 只计算这些选定像素的损失,并将它们加起来作为最终的损失。
class CrossEntropyLoss(nn.Module):def __init__(self,top_k_percent_pixels=None,hard_example_mining_step=100000):super(CrossEntropyLoss, self).__init__()self.top_k_percent_pixels = top_k_percent_pixelsif top_k_percent_pixels is not None:assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)self.hard_example_mining_step = hard_example_mining_step + 1e-5if self.top_k_percent_pixels is None:self.celoss = nn.CrossEntropyLoss(ignore_index=255,reduction='mean')else:self.celoss = nn.CrossEntropyLoss(ignore_index=255,reduction='none')def forward(self, dic_tmp, y, step):total_loss = []for i in range(len(dic_tmp)):pred_logits = dic_tmp[i]gts = y[i]if self.top_k_percent_pixels is None:final_loss = self.celoss(pred_logits, gts)else:# Only compute the loss for top k percent pixels.# First, compute the loss for all pixels. Note we do not put the loss# to loss_collection and set reduction = None to keep the shape.num_pixels = float(pred_logits.size(2) * pred_logits.size(3))pred_logits = pred_logits.view(-1, pred_logits.size(1),pred_logits.size(2) * pred_logits.size(3))gts = gts.view(-1, gts.size(1) * gts.size(2))pixel_losses = self.celoss(pred_logits, gts)if self.hard_example_mining_step == 0:top_k_pixels = int(self.top_k_percent_pixels * num_pixels)else:ratio = min(1.0,step / float(self.hard_example_mining_step))top_k_pixels = int((ratio * self.top_k_percent_pixels +(1.0 - ratio)) * num_pixels)top_k_loss, top_k_indices = torch.topk(pixel_losses,k=top_k_pixels,dim=1)final_loss = torch.mean(top_k_loss)final_loss = final_loss.unsqueeze(0)total_loss.append(final_loss)total_loss = torch.cat(total_loss, dim=0)return total_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1547450.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

了解针对基座大语言模型(类似 ChatGPT 的架构,Decoder-only)的重头预训练和微调训练

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 随着自然语言处理&#xff08;NLP&#xff09;技术的飞速进步&#xff0c;基于 Transformer 架构的大语言模型在众多任务中取得了显著成就。特别是 Decoder-only 架构&#xff0c;如 GPT 系列模型&…

8.7基于数学形态学的边缘检测

基本概念 数学形态学&#xff08;Mathematical Morphology&#xff09;是一套用于图像处理的技术&#xff0c;它包括膨胀&#xff08;Dilation&#xff09;、腐蚀&#xff08;Erosion&#xff09;、开运算&#xff08;Opening&#xff09;和闭运算&#xff08;Closing&#xf…

使用电子模拟器 Wokwi 运行 ESP32 示例(Arduino IDE、VSCode、ESP32C3)

文章目录 Wokwi 简介安装客户端&#xff08;Mac/Linux&#xff09;创建 Token Arduino IDEVSCode 配置安装 wokwi 插件打开编译后目录 ESP32C3 示例Arduino IDE创建模拟器运行模拟器 Wokwi 简介 Wokwi 是一款在线电子模拟器。您可以使用它来模拟 Arduino、ESP32、STM32 以及许…

HTML·第3章 表格布局与表单交互

3.1 表格概述 3.1.1 表格的结构 表格是由行和列组成的二维表&#xff0c;而每行又由一个或多个单元格组成&#xff0c;用于放置数据或其他内容。表格中的单元格是行与列的交叉部分&#xff0c;是组成表格的最基本单元。单元格的内容是数据&#xff0c;也称数据单元格。数据单元…

线上环境排故思路与方法GC优化策略

前言 这是针对于我之前[博客]的一次整理&#xff0c;因为公司需要一些技术文档的定期整理与分享&#xff0c;我就整理了一下。(https://blog.csdn.net/TT_4419/article/details/141997617?spm1001.2014.3001.5501) 其实&#xff0c;nginx配置 服务故障转移与自动恢复也是可以…

人工智能开发实战照片智能搜索功能实现

内容提要 项目分析预备知识项目实战 一、项目分析 1、提出问题 随着人民生活水平的提高和手机照相功能的日趋完美&#xff0c;我们不经意中拍摄了很多值得回忆的时刻&#xff0c;一场说走就走的旅行途中也记录下许多令人心动的瞬间&#xff0c;不知不觉之中&#xff0c;我们…

【CSS in Depth 2 精译_040】6.3 CSS 定位技术之:相对定位(下)—— 用纯 CSS 绘制一个三角形

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第一章 层叠、优先级与继承&#xff08;已完结&#xff09;第二章 相对单位&#xff08;已完结&#xff09;第三章 文档流与盒模型&#xff08;已完结&#xff09;第四章 Flexbox 布局&#xff08;已…

RabbitMQ应用

RabbitMQ 共提供了7种⼯作模式, 进⾏消息传递 一、七种模式的概述 1、Simple(简单模式) P&#xff1a;生产者&#xff0c;就是发送消息的程序 C&#xff1a;消费者&#xff0c;就是接收消息的程序 Queue&#xff1a;消息队列&#xff0c;类似⼀个邮箱, 可以缓存消息; ⽣产者…

【微服务即时通讯系统】——brpc远程过程调用、百度开源的RPC框架、brpc的介绍、brpc的安装、brpc使用和功能测试

文章目录 brpc1. brpc的介绍1.1 rpc的介绍1.2 rpc的原理1.3 grpc和brpc 2. brpc的安装3. brpc使用3.1 brpc接口介绍 4. brpc使用测试4.1 brpc同步和异步调用 brpc 1. brpc的介绍 1.1 rpc的介绍 RPC&#xff08;Remote Procedure Call&#xff09;远程过程调用&#xff0c;是一…

使用Postman搞定各种接口token实战

现在许多项目都使用jwt来实现用户登录和数据权限&#xff0c;校验过用户的用户名和密码后&#xff0c;会向用户响应一段经过加密的token&#xff0c;在这段token中可能储存了数据权限等&#xff0c;在后期的访问中&#xff0c;需要携带这段token&#xff0c;后台解析这段token才…

Java Stream流编程入门

流式编程 stream流式编程分为 首先转化为stream中间函数的链接最后的终结函数 怎么转化为stream 单列集合 List<String> list new ArrayList<String>(); Collections.addAll(list,"1","2","3","4","5","…

【MySQL】MVCC及其实现原理

目录 1. 概念介绍 什么是MVCC 什么是当前读和快照读 MVCC的好处 2. MVCC实现原理 隐藏字段 Read View undo-log 数据可见性算法 3. RC和RR隔离级别下MVCC的差异 4. MVCC&#xff0b;Next-key-Lock 防止幻读 1. 概念介绍 什么是MVCC Multi-Version Concurrency Cont…

FGPA实验——触摸按键

本文系列都基于正点原子新起点开发板 FPGA系列 1&#xff0c;verlog基本语法&#xff08;随时更新&#xff09; 2&#xff0c;流水灯&#xff08;待定&#xff09; 3&#xff0c;FGPA实验——触摸按键 一、触摸操作原理实现 分类&#xff1a;电阻式&#xff08;不耐用&…

LeetCode - 850 矩形面积 II

题目来源 850. 矩形面积 II - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你一个轴对齐的二维数组 rectangles 。 对于 rectangle[i] [x1, y1, x2, y2]&#xff0c;其中&#xff08;x1&#xff0c;y1&#xff09;是矩形 i 左下角的坐标&#xff0c; (xi1, yi1) 是该…

通信工程学习:什么是VIM虚拟化基础设施管理器

VIM:虚拟化基础设施管理器 VIM(Virtualized Infrastructure Manager)虚拟化基础设施管理器,是一种负责管理和控制虚拟化环境中所有虚拟资源的工具和系统。以下是关于VIM虚拟化基础设施管理器的详细解释: 一、定义与功能 VIM是网络功能虚拟化(NFV)架构中…

李宏毅机器学习2023-HW10-Adversarial Attack

文章目录 TaskBaselineFGSM (Fast Gradient Sign Method (FGSM)I-FGSM(Iterative Fast Gradient Sign Method)MI-FGSM(Momentum Iterative Fast Gradient Sign Method)M-DI2-FGSM(Diverse Input Momentum Iterative Fast Gradient Sign Method) Reportfgsm attackJepg Compress…

探索5 大 Node.js 功能

目录 单线程 Node.js 工作线程【Worker Threads】 Node.js 进程 进程缺点 工作线程 注意 集群进程模块【Cluster Process Module】 内部发生了什么&#xff1f; 为什么要使用集群 注意&#xff1a; 应用场景&#xff1a; 内置 HTTP/2 支持 这个 HTTP/2 是什么&…

Windows安装Vim,并在PowerShell中直接使用vim

大家好啊&#xff0c;我是豆小匠。 这期介绍下怎么在windows的PowerShell上使用vim&#xff0c;方便在命令行里修改配置文件等。 先上效果图&#xff1a; 1、下载Vim GitHub传送门&#xff1a;https://github.com/vim/vim-win32-installer/releases 选择win-64的版本下载即可&…

VS Code使用Git Bash终端

Git Bash可以运行linux命令&#xff0c;在VS Code的终端界面&#xff0c;找到号旁边的箭头&#xff0c;就能直接切换了 当然&#xff0c;前提是安装了Git Bash&#xff0c;并且在资源管理器里&#xff0c;能鼠标右键出"Git Bash Here"

node.js从入门到快速开发一个简易的web服务器

浏览器中JavaScript学习路径: JavaScript基础语法浏览器内置API(DOMBOM)第三方库(jQuery,art-template等) Node.js的学习路径 JavaScript基础语法Node.js内置API模块(fs、path、http等)第三方API模块(express、mysql等) Node.js安装 通过Node.js 来运行Javascript 代码&am…