深度学习可视化工具——GradCAM

深度学习可视化——GradCAM

    • GradCAM
    • 调用函数
    • 使用函数

GradCAM

  • CAM package
# -*- coding: utf-8 -*-
"""
Created on Fri Sep  2 15:25:33 2022

@author: Lenovo
"""
import cv2
import numpy as np


class ActivationsAndGradients:
    """Extract activations and register gradients from targeted layers.

    A forward hook and a backward hook are registered on every layer obtained
    by iterating ``target_layers``. Calling the instance clears the stored
    lists and runs the wrapped model, so after a forward/backward pass
    ``activations`` and ``gradients`` hold one CPU tensor per hooked layer.
    """

    def __init__(self, model, target_layers, reshape_transform):
        """Register the hooks.

        :param model: Callable model that is run by ``__call__``.
        :param target_layers: Iterable of layers exposing
            ``register_forward_hook`` and a backward-hook registration method.
        :param reshape_transform: Optional callable applied to every captured
            activation and gradient before it is stored; ``None`` disables it.
        """
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output, detached and on the CPU."""
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        # Gradients are computed in reverse order, so prepend to keep the
        # list aligned with ``self.activations``.
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        """Reset the captured tensors and return ``self.model(x)``."""
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        """Remove every hook registered by this instance."""
        for handle in self.handles:
            handle.remove()


class GradCAM:
    """Grad-CAM: gradient-weighted class activation maps.

    Usable as a context manager; the hooks are removed on exit and when the
    object is garbage collected.
    """

    def __init__(self,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        """Put the model in eval mode and hook the target layers.

        :param model: Model to explain; ``model.eval()`` is called on it.
        :param target_layers: Iterable of layers to compute the CAM from.
        :param reshape_transform: Optional callable applied to captured
            activations and gradients.
        :param use_cuda: When true, the model (and later every input tensor)
            is moved with ``.cuda()``.
        """
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    @staticmethod
    def get_cam_weights(grads):
        """Get a vector of weights for every channel in the target layer.

        Methods that return weights channels will typically need to only
        implement this function. Here the weight is the mean of ``grads``
        over axes 2 and 3, kept as size-1 dimensions.
        """
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        """Sum ``output[i, target_category[i]]`` over all samples ``i``."""
        loss = 0
        for i, category in enumerate(target_category):
            loss = loss + output[i, category]
        return loss

    def get_cam_image(self, activations, grads):
        """Weight the activations per channel and sum over axis 1."""
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)
        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        """Return ``(size(-1), size(-2))`` of the input, i.e. (width, height)."""
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        """Build one scaled CAM per hooked layer from the captured tensors.

        :returns: List with one array per layer; each has a new axis of
            size 1 inserted at position 1 so the layers can be concatenated.
        """
        activations_list = [
            a.cpu().data.numpy()
            for a in self.activations_and_grads.activations
        ]
        grads_list = [
            g.cpu().data.numpy()
            for g in self.activations_and_grads.gradients
        ]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for layer_activations, layer_grads in zip(activations_list,
                                                  grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            # Clamp negatives first so the min-max scaling in
            # scale_cam_image starts from zero.
            cam[cam < 0] = 0
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        """Average the per-layer CAMs (clamped at 0) and rescale the result."""
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        """Min-max scale every image in ``cam`` and optionally resize it.

        :param cam: Iterable of 2-D arrays.
        :param target_size: Optional size passed to ``cv2.resize``
            (the caller in this class passes ``(width, height)``).
        :returns: ``np.float32`` array stacking the processed images.
        """
        result = []
        for img in cam:
            img = img - np.min(img)
            # The small epsilon avoids dividing by zero on a constant image.
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target_category=None):
        """Compute the CAM for a batch.

        :param input_tensor: Batched input; its first dimension is the batch
            size.
        :param target_category: ``None`` (use the arg-max of the model output
            for each sample), a single integer applied to every sample, or a
            sequence with one category per sample.
        :returns: ``np.float32`` array produced by
            ``aggregate_multi_layers``.
        """
        if self.cuda:
            input_tensor = input_tensor.cuda()

        # Forward pass; the raw model output is used as the class scores
        # (no softmax is applied here).
        output = self.activations_and_grads(input_tensor)

        # NumPy integer scalars (e.g. an np.argmax result) are not instances
        # of ``int``, so accept them explicitly.
        if isinstance(target_category, (int, np.integer)):
            target_category = [int(target_category)] * input_tensor.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert (len(target_category) == input_tensor.size(0))

        self.model.zero_grad()
        loss = self.get_loss(output, target_category)
        loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(input_tensor)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        # __init__ may have failed before the hooks were created; do not
        # raise from the finalizer in that case.
        hooks = getattr(self, "activations_and_grads", None)
        if hooks is not None:
            hooks.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True

调用函数

import sys
from tqdm import tqdm
import torch
from torchvision import transforms
import numpy as np
import os
import cv2
from PIL import Image
from CAM import GradCAM
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay the CAM mask on the image as a heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if
        ``img`` is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: ``np.uint8`` image with the cam overlay.
    :raises Exception: If ``img`` contains a value greater than 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception("The input image should np.float32 in the range [0, 1]")

    # Add the heatmap to the image, then renormalise so the maximum is 1.
    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def save_cam_mask(cam_path,
                  mask: np.ndarray,
                  w: int = 224,
                  h: int = 224,
                  use_rgb: bool = True,
                  colormap: int = cv2.COLORMAP_JET):
    """Save only the generated CAM (without the base image) as a colour image.

    :param cam_path: Destination path of the CAM image.
    :param mask: The generated CAM, a single-channel map that is converted to
        a colour heatmap here.
    :param w: Width of the saved image.
    :param h: Height of the saved image.
    :param use_rgb: Convert the heatmap from BGR to RGB before saving.
    :param colormap: The OpenCV colormap to be used.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # torchvision's Resize takes the size as (height, width).
    heatmap = transforms.Resize([h, w])(Image.fromarray(heatmap))
    heatmap.save(cam_path)


def get_cam(model, img_path, target_layers, data_transform, target_category=1):
    """Run Grad-CAM on one image and return the overlay at the original size.

    :param model: Model to visualise.
    :param img_path: Path of the image to visualise.
    :param target_layers: Layers whose feature maps are used for the CAM;
        forwarded unchanged to ``GradCAM``.
    :param data_transform: Callable turning the PIL image into the model's
        input tensor (without the batch dimension).
    :param target_category: Class index to explain. Defaults to 1, the value
        that used to be hard-coded in this function.
    :returns: PIL image of the CAM blended over the input image, resized back
        to the original image size.
    """
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')
    # PIL reports the size as (width, height).
    width, height = img.size

    img_tensor = data_transform(img)
    input_tensor = torch.unsqueeze(img_tensor, dim=0)

    # GradCAM is created with use_cuda=False, so it will not move the input
    # itself; put the input on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    # The CAM comes back at the network input resolution, so the base image
    # has to be resized to that same resolution before blending.
    in_height, in_width = input_tensor.shape[-2], input_tensor.shape[-1]
    orimg = np.array(img.resize((in_width, in_height)), dtype=np.uint8)

    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)

    # Generate the single-channel CAM for the (only) image in the batch.
    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category)
    grayscale_cam = grayscale_cam[0, :]

    # Blend the CAM over the resized input image.
    visualization = show_cam_on_image(orimg.astype(dtype=np.float32) / 255.,
                                      grayscale_cam,
                                      use_rgb=True)
    # Resize expects (height, width); restore the original image size.
    visualization = transforms.Resize([height, width])(
        Image.fromarray(visualization))
    return visualization

使用函数

# Grad-CAM
import os
import cv2
import torch
import utils
from torchvision import transforms
from model import convnext_tiny as create_model
from PIL import Image
num_classes = 2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    # NOTE(review): the usual ImageNet statistics are mean
    # [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225]; the first two values
    # look swapped here. Left unchanged because they must match whatever the
    # checkpoint was trained with - confirm.
    transforms.Normalize([0.456, 0.485, 0.406], [0.224, 0.229, 0.225]),
])
img_dir = '../test_data/img/'
cam_dir = '../test_data/cam/'
# Saving fails if the output directory is missing, so create it up front.
os.makedirs(cam_dir, exist_ok=True)
color_list = os.listdir(img_dir)
#%%
# NOTE(review): '..t/weights' looks like a typo for '../weights' - confirm.
weights_path = '..t/weights/patch.pth'
model = create_model(num_classes=num_classes).to(device)
model.load_state_dict(torch.load(weights_path, map_location=device))
# NOTE(review): this passes a module rather than a list of layers; presumably
# the stage is iterable and each of its children gets hooked - confirm that
# this is intended rather than ``[model.stages[3]]``.
target_layers = model.stages[3]
for picture in color_list:
    img_path = os.path.join(img_dir, picture)
    cam_path = os.path.join(cam_dir, picture)
    cam = utils.get_cam(model, img_path, target_layers, data_transform)
    # Write the overlay next to the others, under the input file's name.
    cam.save(cam_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/35619.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

1.1 Beginner Level学习之“使用 rosed 在 ROS 中编辑文件”(第九节)

学习大纲： 1. 使用 rosed rosed 是 ROS 自带的 Rosbash Suite 的一部分，它的目的是让你通过 ROS 包的名称快速编辑文件，而不用手动输入完整的路径，节省开发时间。 基本用法：$ rosed [package_name] [filename] 示例…

MySQL语句学习第三篇_数据库

MySQL语句学习第三篇_数据库 专栏记录MySQL的学习，感谢大家观看。 本章的专栏📚➡️MySQL语法学习 本博客前一章节指向➡️MySQL语句学习第二篇 本人的博客➡️:如烟花般绚烂却又稍纵即逝的主页 文章目录 MySQL的基础操作（改与查…

HCIA-openGauss_2_2连接与认证

设置客户端认证策略 设置配置文件参数 gssql客户端连接-确定连接信息 客户端工具通过数据库主节点连接数据库，因此连接前，需要获取数据库主节点的在服务器的IP地址及数据库主节点的端口号信息。 步骤1：以操作系统用户omm登录数据库主节点。…

什么?RayLink远程控制软件支持企业IT应用!

在当今企业IT管理中，远程控制工具扮演着不可或缺的角色。设想一下，你的团队成员分散在全球各地，或者员工正在远程工作，这时电脑突然出现问题。如果IT支持团队能够利用远程控制软件，比如RayLink，迅速远程接入…

【C++】——精细化哈希表架构:理论与实践的综合分析

先找出你的能力在哪里，然后再决定你是谁。 —— 塔拉韦斯特弗 《你当像鸟飞往你的山》 目录 1. C++ 与哈希表：核心概念与引入 2. 哈希表的底层机制：原理与挑战 2.1 核心功能解析：效率与灵活性的平衡 2.2 哈希冲突的本质…

12月第1周AI资讯

阅读时间:3-4min 更新时间:2024.12.2-2024.12.6 目录 OpenAI CEO Sam Altman 预告“12天OpenAI”系列活动 腾讯HunyuanVideo:130亿参数的开源视频生成模型 李飞飞的World Labs发布空间智能技术预览版 中科院联手腾讯打造“AI带货王”AnchorCrafter OpenAI CEO Sam Alt…

10_C语言 -数组(常规)

数组 引例 如果我们要在程序中表示一个学生的成绩，我们会使用一个int来表示，如：int score。假如我们要在程序中表示一组成绩，此时我们所学的常规数据类型就无法再表示，这个时候我们就需要使用到一种新的表现形式…

红蓝对抗之Windows内网渗透

前言 无论是渗透测试，还是红蓝对抗，目的都是暴露风险，促进提升安全水平。企业往往在外网布置重兵把守，而内网防护相对来说千疮百孔，所以渗透高手往往通过攻击员工电脑、外网服务、职场WiFi等方式进入内网，…

Google推出 PaliGemma 2

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗？订阅我们的简报，深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同，从行业内部的深度分析和实用指南中受益。不要错过这个机会，成为AI领…

Spring IoC的基本概念

引言 在 Java 中，出现了大量轻量级容器，这些容器有助于将来自不同项目的组件组装成一个有凝聚力的应用程序。这些容器的底层是它们如何执行布线的常见模式，它们将这一概念称为“控制反转”。 🏢 本章内容 🏭 IoC服务…

图神经网络GNN入门

参考教程：A Gentle Introduction to Graph Neural Networks 图神经网络（Graph Neural Networks，GNNs）是一类专门用于处理图结构数据的神经网络，旨在通过节点、边和图的结构信息来学习图中节点和图的表示。GNN通过消息传…

卧式螺旋混合机搅拌机:饲料加工设备

卧式螺旋混合机搅拌机是一种用于饲料混合的设备，其结构特点为卧式，即搅拌桶体水平放置。这种设计使得物料在搅拌过程中能够充分混合，且搅拌效率高、混合均匀度好。卧式饲料混合机广泛应用于畜牧业、养殖业以及饲料加工行业，是饲料…

【北京迅为】iTOP-4412全能版使用手册-第四十二章 驱动注册

iTOP-4412全能版采用四核Cortex-A9，主频为1.4GHz-1.6GHz，配备S5M8767 电源管理，集成USB HUB,选用高品质板对板连接器稳定可靠，大厂生产，做工精良。接口一应俱全，开发更简单,搭载全网通4G、支持WIFI、蓝牙、…

交易系统:线上交易系统流程详解

大家好，我是汤师爷~ 今天聊聊线上交易系统流程详解。 线上交易系统为新零售连锁商家提供一站式线上交易解决方案。其核心目标是，通过数字化手段扩大商家的服务范围，突破传统门店的地理限制。系统支持电商、O2O等多种业务形态，为…

Postman接口测试详解

🍅 点击文末小卡片，免费获取软件测试全套资料，资料在手，涨薪更快 pre-request script 介绍 在过往的工作中，遇到很多测试小伙伴使用 postman 的时候都是直接通过 api 文档的描述请求，检查返回的数据是否正…

【单链表】(更新中...)

一、 题单 206.反转链表 203.移除链表元素 876.链表的中间结点 BM8 链表中倒数最后k个结点 21.合并两个有序链表 二、题目简介及思路 206.反转链表 给你单链表的头节点 head ，请你反转链表，并返回反转后的链表。 思路简单，但是除了要两个指针进…

深入理解 SQL 注入:原理、攻击流程与防御措施

深入理解 SQL 注入：原理、攻击流程与防御措施 在当今数字化的时代，数据安全已成为每个企业和开发者必须面对的重要课题。SQL 注入（SQL Injection）作为一种常见的网络攻击方式，给无数企业带来了巨大的损失。本文将深入…

市场上显卡型号需求分析

两个平台统计：（关键词统计，仅做参考） GPU型号｜平台 github(提交量/万) huggingface（模型量/个） H100 6.6 210 A100 17.2 483 V100 14.4 484 4090 27.3 31 3090 11.1 92 在git…

C# WPF抽奖程序

C# WPF抽奖程序 using Microsoft.Win32; using System; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Windows; using System.…

Master EDI 项目需求分析

Master Electronics 通过其全球分销网络，支持多种采购需求，确保能够为客户提供可靠的元件供应链解决方案，同时为快速高效地与全球伙伴建立合作，Master 选择通过EDI来实现与交易伙伴间的数据传输。 EDI为交易伙伴之间建立了一个安…