多分类中混淆矩阵的TP,TN,FN,FP计算

关于混淆矩阵,各位可以在这里了解:混淆矩阵细致理解_夏天是冰红茶的博客-CSDN博客

上一篇中我们了解了混淆矩阵,并且进行了类定义,那么在这一节中我们将要对其进行扩展,在多分类中,如何去计算TP,TN,FN,FP。

原理推导

这里以三分类为例,这里来看看TP,TN,FN,FP是怎么分布的。

类别1的标签:

类别2的标签:

类别3的标签:

这样我们就能知道了混淆矩阵的对角线就是TP

TP = torch.diag(h)

 假正例(FP)是模型错误地将负类别样本分类为正类别的数量

FP = torch.sum(h, dim=1) - TP

假负例(FN)是模型错误地将正类别样本分类为负类别的数量

FN = torch.sum(h, dim=0) - TP

最后用总数减去除了 TP 的其他三个元素之和得到 TN

TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)

逻辑验证

这里借用上一篇的例子,假如我们这个混淆矩阵是这样的:

tensor([[2, 0, 0],
            [0, 1, 1],
            [0, 2, 0]])

为了方便讲解,这里我们对其进行一个简单的编号,即0—8:

012
345
678

torch.sum(h, dim=1) 可得 tensor([2., 2., 2.]) , torch.sum(h, dim=0) 可得 tensor([2., 3., 1.]) 。

  •  TP:   tensor([2., 1., 0.]) 
  •  FN:   tensor([0., 1., 2.]) 
  •  TN:   tensor([4., 2., 3.]) 
  •  FP:   tensor([0., 2., 1.])

我们先来看看TP的构成,对应着矩阵的对角线2,1,0;FP在类别1中占3,6号位,在类别2中占1,7号位,在类别3中占2,5号位,加起来即为0,1,2;TN在类别1中占4,5,7,8号位,在类别2中占边角位,在类别3中占0,1,3,4号位,加起来即为4,2,3;FN在类别1中占1,2号位,在类别2中占3,5号位,在类别3中占6,7号位,加起来即为0,2,1。

补充类定义

import torch
import numpy as npclass ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, t, p):n = self.num_classesif self.mat is None:# 创建混淆矩阵self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)with torch.no_grad():# 寻找GT中为目标的像素索引k = (t >= 0) & (t < n)# 统计像素真实类别t[k]被预测成类别p[k]的个数inds = n * t[k].to(torch.int64) + p[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()@propertydef ravel(self):"""计算混淆矩阵的TN, FP, FN, TP"""h = self.mat.float()n = self.num_classesif n == 2:TP, FN, FP, TN = h.flatten()return TP, FN, FP, TNif n > 2:TP = h.diag()FN = h.sum(dim=1) - TPFP = h.sum(dim=0) - TPTN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)return TP, FN, FP, TNdef compute(self):"""主要在eval的时候使用,你可以调用ravel获得TN, FP, FN, TP, 进行其他指标的计算计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)计算每个类别的准确率计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)"""h = self.mat.float()acc_global = torch.diag(h).sum() / h.sum()acc = torch.diag(h) / h.sum(1)iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)

我在代码中添加了属性修饰器,以便我们可以直接的进行调用,并且也考虑到了二分类与多分类不同的情况。

性能指标

关于这些指标在网上有很多介绍,这里就不细讲了

class ModelIndex():def __init__(self,TP, FN, FP, TN, e=1e-5):self.TN = TNself.FP = FPself.FN = FNself.TP = TPself.e = edef Precision(self):"""精确度衡量了正类别预测的准确性"""return self.TP / (self.TP + self.FP + self.e)def Recall(self):"""召回率衡量了模型对正类别样本的识别能力"""return self.TP / (self.TP + self.FN + self.e)def IOU(self):"""表示模型预测的区域与真实区域之间的重叠程度"""return self.TP / (self.TP + self.FP + self.FN + self.e)def F1Score(self):"""F1分数是精确度和召回率的调和平均数"""p = self.Precision()r = self.Recall()return (2*p*r) / (p + r + self.e)def Specificity(self):"""特异性是指模型在负类别样本中的识别能力"""return self.TN / (self.TN + self.FP + self.e)def Accuracy(self):"""准确度是模型正确分类的样本数量与总样本数量之比"""return (self.TP + self.TN) / (self.TP + self.TN + self.FP + self.FN + self.e)def FP_rate(self):"""False Positive Rate,假阳率是模型将负类别样本错误分类为正类别的比例"""return self.FP / (self.FP + self.TN + self.e)def FN_rate(self):"""False Negative Rate,假阴率是模型将正类别样本错误分类为负类别的比例"""return self.FN / (self.FN + self.TP + self.e)def Qualityfactor(self):"""品质因子综合考虑了召回率和特异性"""r = self.Recall()s = self.Specificity()return r+s-1

参考文章:多分类中TP/TN/FP/FN的计算_Hello_Chan的博客-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/139820.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

IDEA中创建Java Web项目方法1

以下过程使用IntelliJ IDEA 2021.3 一、File-> New -> Project... 1. 项目类型中选择 Java Enterprise 项目 2. Name&#xff1a;填写自己的项目名称 3. Project template&#xff1a;选择项目的模板&#xff0c;Web application。支持JSP和Servlet的项目 4. Applica…

Nginx location 精准匹配URL = /

Location是什么&#xff1f; Location是Nginx中的块级指令(block directive)&#xff0c;通过配置Location指令块&#xff0c;可以决定客户端发过来的请求URI如何处理&#xff08;是映射到本地文件还是转发出去&#xff09;及被哪个location处理。 匹配模式 分为两种模式&…

位段 联合体 枚举

Hello好久不见&#xff0c;今天分享的是接上次结构体没有分享完的内容&#xff0c;这次我们讲讲位段 枚举和联合体的概念以及他们的用法。 2.1 什么是位段 位段的声明和结构是类似的&#xff0c;有两个不同&#xff1a; 1.位段的成员必须是 int、unsigned int 或signed int 。 …

微信开放平台第三方开发,实现代小程序备案申请

大家好&#xff0c;我是小悟 微信小程序备案整体流程总共分为五个环节&#xff1a;备案信息填写、平台初审、工信部短信核验、通管局审核和备案成功。 服务商可以代小程序发起备案申请。在申请小程序备案之前&#xff0c;需要确保小程序基本信息已填写完成、小程序至少存在一个…

com.google.gson.internal.LinkedTreeMap cannot be cast to XXX

起因是在对google商品做本地缓存时&#xff0c;上线后发现的bug 刚开始非常自信&#xff0c;debug没问题线上有问题&#xff0c;大概率就是混淆文件没有添加keep&#xff0c;于是本地添加对SDK中类的keep&#xff0c;本地打包release验证&#xff0c;不出意外还是崩溃 仔细看…

C语言指针变量的引用距离

本段代码&#xff0c;测试&#xff0c;C的函数传参中&#xff0c;形参是基础类型参数和地址参数&#xff0c;对于实参的值影响。 #include <stdio.h> add(int a,int b){a;b;printf("add副本a%d\n",a);printf("add副本b%d\n",b);printf("副本ca…

Interceptor的使用场景:拦截请求中的租户信息,注入到租户上下文中

业务场景 在SaaS环境中&#xff0c;租户是最重要的隔离业务数据的属性了&#xff0c;在自己的项目体系环境中&#xff0c;租户id能保证有值。但有个特殊场景&#xff0c;某些特殊权限的账号需要修改指定租户的内容&#xff0c;也即前端会携带租户信息过来&#xff0c;并且内部涉…

在github上设置不同分支,方便回滚

在github上设置不同分支&#xff0c;方便回滚 步骤可能出现的问题couldnt find remote ref gpuVersion1. 确保您处于正确的分支2. 添加并提交更改&#xff08;如果还未进行&#xff09;3. 推送本地分支到远程仓库4. 验证操作 步骤 之前在github上上传了一个项目代码&#xff0c…

【马蹄集】—— 数论专题:筛法

数论专题 目录 MT2213 质数率MT2214 元素共鸣MT2215 小码哥的喜欢数MT2216 数的自我MT2217 数字游戏 MT2213 质数率 难度&#xff1a;黄金    时间限制&#xff1a;1秒    占用内存&#xff1a;256M 题目描述 请求出 [ 1 , n ] \left[1,n\right] [1,n] 范围内质数占比率。…

Unity的AB包相关

1、打包 在这个界面左边右键&#xff0c;CreateNewBundle 将要打包的模型制作成预设体 在下面勾选 选好平台路径&#xff0c;点击Build 2、加载AB包 public class ABTest : MonoBehaviour {// Start is called before the first frame updatevoid Start(){//加载AB包AssetB…

识别准确率达 95%,华能东方电厂财务机器人实践探索

摘 要&#xff1a;基于华能集团公司大数据与人工智能构想理念&#xff0c;结合东方电厂实际工作需要&#xff0c;财务工作要向数字化、智能化纵深推进&#xff0c;随着财务数字化转型和升级加速&#xff0c;信息化水平不断提升&#xff0c;以及内部信息互联互通不断加深&#x…

深入探析NCV7356D1R2G 单线CAN收发器各项参数

NCV7356D1R2G深力科是一款用于单线数据链路的物理层器件&#xff0c;能够使用多种具碰撞分解的载波感测多重存取 (CSMA/CR) 协议运行&#xff0c;如博世控制器区域网络 (CAN) 2.0 版。此串行数据链路网络适用于不需要高速数据的应用&#xff0c;低速数据可在物理介质部件和微处…

【用unity实现100个游戏之12】unity制作一个俯视角2DRPG《类星露谷物语》资源收集游戏demo

文章目录 前言加快编辑器运行速度素材(1)场景人物(2)工具 一、人物移动和动画切换二、走路灰尘粒子效果探究实现 三、树木排序设计方法一方法二 四、绘制拿工具的角色动画五、砍树实现六、存储拾取物品引入Unity 的可序列化字典类 七、实现靠近收获物品自动吸附八、树木被砍掉的…

视频编解码器H.264和H265有什么区别?

对于大型视频文件来说&#xff0c;视频编解码器至关重要&#xff0c;它可以将文件压缩为较小的尺寸&#xff0c;从而可以更轻松地存储和加快传输速度。而两种最常用的编解码器是H.264和H.265&#xff0c;那么它们两者之间有什么区别&#xff0c;哪一个更好呢&#xff1f; 1. 什…

手摸手图解 CodeWhisperer 的安装使用

CodeWhisperer 是亚⻢逊出品的一款基于机器学习的通用代码生成器&#xff0c;可实时提供代码建议。 亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术&#xff0c;观点…

JCEF中js与java交互、js与java相互调用

jcef中js与java相互调用&#xff0c;java与js相互调用&#xff0c;chrome与java相互调用&#xff0c;java与chrome相互调用、jcef与java相互调用 前提&#xff1a;https://blog.csdn.net/weixin_44480167/article/details/133170970&#xff08;java内嵌浏览器CEF-JAVA、jcef、…

车辆检测:An Efficient Wide-Range Pseudo-3D Vehicle Detection Using A Single Camera

论文作者&#xff1a;Zhupeng Ye,Yinqi Li,Zejian Yuan 作者单位&#xff1a;Xian Jiaotong University 论文链接&#xff1a;http://arxiv.org/abs/2309.08369v1 项目链接&#xff1a;https://www.youtube.com/watch?v1gk1PmsQ5Q8 内容简介&#xff1a; 1&#xff09;方…

【数据结构】二叉树之堆的实现

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;数据结构 &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、二叉树的顺序结构 &#x1f4d2;1.1顺序存储 &#x1f4d2;1.2堆的性质…

【LeetCode75】第六十二题 多米诺和托米诺平铺

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我一个数字n&#xff0c;表示我们有2*n大小的地板需要铺。 我们拥有两种瓷砖&#xff0c;一种的长度为2的多米诺&#xff0c;另一…

CFCA证书 申请 流程(一)

跳过科普&#xff0c;可直接进入申请&#x1f449;https://blog.csdn.net/Ximerr/article/details/133169391 CFCA证书 CFCA证书是指由中国金融认证中心颁发的证书&#xff0c;包括普通数字证书、服务器数字证书和预植证书等&#xff0c;目前&#xff0c;各大银行和金融机构都…