运行结果如下(前提:对应类别的 YOLOv8 模型已训练完成):
具体实现:
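打开 ultralytics\utils\plotting.py 文件,
按 Ctrl+F 搜索 box_label 方法,
在其 cv2 分支中绘制普通(非旋转)矩形框的语句 cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) 之后,添加以下代码(注意:p2 只在这个非旋转分支中定义,不要放进 rotated 分支,否则 OBB 旋转框任务会报错):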
# Draw the pixel coordinates of the four box corners next to each corner.
# Paste inside Annotator.box_label, in the non-rotated cv2 branch, directly after
# `cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)`.
# p2 (the bottom-right corner) is only assigned in that branch, so placing this in the
# rotated (OBB) branch raises UnboundLocalError.
corner_labels = (
    ((p1[0], p1[1]), (p1[0], p1[1] + 20)),  # top-left: text just below the corner
    ((p2[0], p1[1]), (p2[0], p1[1] + 20)),  # top-right
    ((p1[0], p2[1]), (p1[0] - 20, p2[1] + 20)),  # bottom-left: shifted left and below the box
    ((p2[0], p2[1]), (p2[0] - 20, p2[1] + 20)),  # bottom-right
)
for (corner_x, corner_y), org in corner_labels:
    cv2.putText(
        self.im,
        f"({corner_x}, {corner_y})",
        org,
        0,
        self.lw / 3,
        (0, 0, 225),
        thickness=2,
        lineType=cv2.LINE_AA,
    )
将上述代码复制到对应位置后,运行测试文件或预测命令,即可在检测框的四个角旁显示其像素坐标。
plotting.py完整代码如下:
# Ultralytics YOLO 🚀, AGPL-3.0 licenseimport contextlib
import math
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Unionimport cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_versionfrom ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
from ultralytics.utils.checks import check_font, check_version, is_ascii
from ultralytics.utils.files import increment_pathclass Colors:"""Ultralytics default color palette https://ultralytics.com/.This class provides methods to work with the Ultralytics color palette, including converting hex color codes toRGB values.Attributes:palette (list of tuple): List of RGB color values.n (int): The number of colors in the palette.pose_palette (np.ndarray): A specific color palette array with dtype np.uint8."""def __init__(self):"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""hexs = ("042AFF","0BDBEB","F3F3F3","00DFB7","111F68","FF6FDD","FF444F","CCED00","00F344","BD00FF","00B4FF","DD00BA","00FFFF","26C000","01FFB3","7D24FF","7B0068","FF1B6C","FC6D2F","A2FF0B",)self.palette = [self.hex2rgb(f"#{c}") for c in hexs]self.n = len(self.palette)self.pose_palette = np.array([[255, 128, 0],[255, 153, 51],[255, 178, 102],[230, 230, 0],[255, 153, 255],[153, 204, 255],[255, 102, 255],[255, 51, 255],[102, 178, 255],[51, 153, 255],[255, 153, 153],[255, 102, 102],[255, 51, 51],[153, 255, 153],[102, 255, 102],[51, 255, 51],[0, 255, 0],[0, 0, 255],[255, 0, 0],[255, 255, 255],],dtype=np.uint8,)def __call__(self, i, bgr=False):"""Converts hex color codes to RGB values."""c = self.palette[int(i) % self.n]return (c[2], c[1], c[0]) if bgr else c@staticmethoddef hex2rgb(h):"""Converts hex color codes to RGB values (i.e. 
default PIL order)."""return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))colors = Colors() # create instance for 'from utils.plots import colors'class Annotator:"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.Attributes:im (Image.Image or numpy array): The image to annotate.pil (bool): Whether to use PIL or cv2 for drawing annotations.font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.lw (float): Line width for drawing.skeleton (List[List[int]]): Skeleton structure for keypoints.limb_color (List[int]): Color palette for limbs.kpt_color (List[int]): Color palette for keypoints."""def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillicinput_is_pil = isinstance(im, Image.Image)self.pil = pil or non_ascii or input_is_pilself.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)if self.pil: # use PILself.im = im if input_is_pil else Image.fromarray(im)self.draw = ImageDraw.Draw(self.im)try:font = check_font("Arial.Unicode.ttf" if non_ascii else font)size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)self.font = ImageFont.truetype(str(font), size)except Exception:self.font = ImageFont.load_default()# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)if check_version(pil_version, "9.2.0"):self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, heightelse: # use cv2assert im.data.contiguous, "Image not contiguous. 
Apply np.ascontiguousarray(im) to Annotator input images."self.im = im if im.flags.writeable else im.copy()self.tf = max(self.lw - 1, 1) # font thicknessself.sf = self.lw / 3 # font scale# Poseself.skeleton = [[16, 14],[14, 12],[17, 15],[15, 13],[12, 13],[6, 12],[7, 13],[6, 7],[6, 8],[7, 9],[8, 10],[9, 11],[2, 3],[1, 2],[1, 3],[2, 4],[3, 5],[4, 6],[5, 7],]self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]self.dark_colors = {(235, 219, 11),(243, 243, 243),(183, 223, 0),(221, 111, 255),(0, 237, 204),(68, 243, 0),(255, 255, 0),(179, 255, 1),(11, 255, 162),}self.light_colors = {(255, 42, 4),(79, 68, 255),(255, 0, 189),(255, 180, 0),(186, 0, 221),(0, 192, 38),(255, 36, 125),(104, 0, 123),(108, 27, 255),(47, 109, 252),(104, 31, 17),}def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):"""Assign text color based on background color."""if color in self.dark_colors:return 104, 31, 17elif color in self.light_colors:return 255, 255, 255else:return txt_colordef circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):"""Draws a label with a background circle centered within a given bounding box.Args:box (tuple): The bounding box coordinates (x1, y1, x2, y2).label (str): The text label to be displayed.color (tuple, optional): The background color of the rectangle (B, G, R).txt_color (tuple, optional): The color of the text (R, G, B).margin (int, optional): The margin between the text and the rectangle border."""# If label have more than 3 characters, skip other characters, due to circle sizeif len(label) > 3:print(f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!")label = label[:3]# Calculate the center of the boxx_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)# Get the text sizetext_size = 
cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]# Calculate the required radius to fit the text with the marginrequired_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin# Draw the circle with the required radiuscv2.circle(self.im, (x_center, y_center), required_radius, color, -1)# Calculate the position for the texttext_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2# Draw the textcv2.putText(self.im,str(label),(text_x, text_y),cv2.FONT_HERSHEY_SIMPLEX,self.sf - 0.15,self.get_txt_color(color, txt_color),self.tf,lineType=cv2.LINE_AA,)def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):"""Draws a label with a background rectangle centered within a given bounding box.Args:box (tuple): The bounding box coordinates (x1, y1, x2, y2).label (str): The text label to be displayed.color (tuple, optional): The background color of the rectangle (B, G, R).txt_color (tuple, optional): The color of the text (R, G, B).margin (int, optional): The margin between the text and the rectangle border."""# Calculate the center of the bounding boxx_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)# Get the size of the texttext_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]# Calculate the top-left corner of the text (to center it)text_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2# Calculate the coordinates of the background rectanglerect_x1 = text_x - marginrect_y1 = text_y - text_size[1] - marginrect_x2 = text_x + text_size[0] + marginrect_y2 = text_y + margin# Draw the background rectanglecv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)# Draw the text on top of the rectanglecv2.putText(self.im,label,(text_x, text_y),cv2.FONT_HERSHEY_SIMPLEX,self.sf - 0.1,self.get_txt_color(color, txt_color),self.tf,lineType=cv2.LINE_AA,)def box_label(self, box, label="", 
color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):"""Draws a bounding box to image with label.Args:box (tuple): The bounding box coordinates (x1, y1, x2, y2).label (str): The text label to be displayed.color (tuple, optional): The background color of the rectangle (B, G, R).txt_color (tuple, optional): The color of the text (R, G, B).rotated (bool, optional): Variable used to check if task is OBB"""txt_color = self.get_txt_color(color, txt_color)if isinstance(box, torch.Tensor):box = box.tolist()if self.pil or not is_ascii(label):if rotated:p1 = box[0]self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple boxelse:p1 = (box[0], box[1])self.draw.rectangle(box, width=self.lw, outline=color) # boxif label:w, h = self.font.getsize(label) # text width, heightoutside = p1[1] >= h # label fits outside boxif p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of imagep1 = self.im.size[0] - w, p1[1]self.draw.rectangle((p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),fill=color,)# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)else: # cv2if rotated:p1 = [int(b) for b in box[0]]cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray boxelse:p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)# 在预测图中绘制一个中心坐标红点# center_x = (p1[0] + p2[0]) // 2# center_y = (p1[1] + p2[1]) // 2# cv2.circle(self.im, (center_x, center_y), self.lw, (0, 0, 225), self.lw)# 创建中心点坐标变量# Center = (center_x, center_y)# 用于在图像上添加文本# cv2.putText(self.im, str(Center), (center_x, center_y), 0, self.lw / 3, (0, 0, 225), thickness=4, lineType=cv2.LINE_AA)# 添加左上角、右上角、左下角和右下角的坐标# 左上角cv2.putText(self.im, 
f"({p1[0]}, {p1[1]})", (p1[0], p1[1]+20), 0, self.lw / 3, (0, 0, 225), thickness=2, lineType=cv2.LINE_AA)# 右上角cv2.putText(self.im, f"({p2[0]}, {p1[1]})", (p2[0], p1[1]+20), 0, self.lw / 3, (0, 0, 225), thickness=2, lineType=cv2.LINE_AA)# 左下角cv2.putText(self.im, f"({p1[0]}, {p2[1]})", (p1[0]-20, p2[1]+20), 0, self.lw / 3, (0, 0, 225), thickness=2, lineType=cv2.LINE_AA)# 右下角cv2.putText(self.im, f"({p2[0]}, {p2[1]})", (p2[0]-20, p2[1]+20), 0, self.lw / 3, (0, 0, 225), thickness=2, lineType=cv2.LINE_AA)if label:w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, heighth += 3 # add pixels to pad textoutside = p1[1] >= h # label fits outside boxif p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of imagep1 = self.im.shape[1] - w, p1[1]p2 = p1[0] + w, p1[1] - h if outside else p1[1] + hcv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filledcv2.putText(self.im,label,(p1[0], p1[1] - 2 if outside else p1[1] + h - 1),0,self.sf,txt_color,thickness=self.tf,lineType=cv2.LINE_AA,)def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):"""Plot masks on image.Args:masks (tensor): Predicted masks on cuda, shape: [n, h, w]colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaqueretina_masks (bool): Whether to use high resolution masks or not. 
Defaults to False."""if self.pil:# Convert to numpy firstself.im = np.asarray(self.im).copy()if len(masks) == 0:self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255if im_gpu.device != masks.device:im_gpu = im_gpu.to(masks.device)colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)colors = colors[:, None, None] # shape(n,1,1,3)masks = masks.unsqueeze(3) # shape(n,h,w,1)masks_color = masks * (colors * alpha) # shape(n,h,w,3)inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)mcs = masks_color.max(dim=0).values # shape(n,h,w,3)im_gpu = im_gpu.flip(dims=[0]) # flip channelim_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)im_gpu = im_gpu * inv_alpha_masks[-1] + mcsim_mask = im_gpu * 255im_mask_np = im_mask.byte().cpu().numpy()self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)if self.pil:# Convert im back to PIL and update drawself.fromarray(self.im)def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25, kpt_color=None):"""Plot keypoints on the image.Args:kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).radius (int, optional): Keypoint radius. Defaults to 5.kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.conf_thres (float, optional): Confidence threshold. Defaults to 0.25.kpt_color (tuple, optional): Keypoint color (B, G, R). 
Defaults to None.Note:- `kpt_line=True` currently only supports human pose plotting.- Modifies self.im in-place.- If self.pil is True, converts image to numpy array and back to PIL."""if self.pil:# Convert to numpy firstself.im = np.asarray(self.im).copy()nkpt, ndim = kpts.shapeis_pose = nkpt == 17 and ndim in {2, 3}kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plottingfor i, k in enumerate(kpts):color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))x_coord, y_coord = k[0], k[1]if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:if len(k) == 3:conf = k[2]if conf < conf_thres:continuecv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)if kpt_line:ndim = kpts.shape[-1]for i, sk in enumerate(self.skeleton):pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))if ndim == 3:conf1 = kpts[(sk[0] - 1), 2]conf2 = kpts[(sk[1] - 1), 2]if conf1 < conf_thres or conf2 < conf_thres:continueif pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:continueif pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:continuecv2.line(self.im,pos1,pos2,kpt_color or self.limb_color[i].tolist(),thickness=2,lineType=cv2.LINE_AA,)if self.pil:# Convert im back to PIL and update drawself.fromarray(self.im)def rectangle(self, xy, fill=None, outline=None, width=1):"""Add rectangle to image (PIL-only)."""self.draw.rectangle(xy, fill, outline, width)def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):"""Adds text to an image using PIL or cv2."""if anchor == "bottom": # start y from font bottomw, h = self.font.getsize(text) # text width, heightxy[1] += 1 - hif self.pil:if box_style:w, h = self.font.getsize(text)self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)# Using `txt_color` for background and draw fg with white 
colortxt_color = (255, 255, 255)if "\n" in text:lines = text.split("\n")_, h = self.font.getsize(text)for line in lines:self.draw.text(xy, line, fill=txt_color, font=self.font)xy[1] += helse:self.draw.text(xy, text, fill=txt_color, font=self.font)else:if box_style:w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, heighth += 3 # add pixels to pad textoutside = xy[1] >= h # label fits outside boxp2 = xy[0] + w, xy[1] - h if outside else xy[1] + hcv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled# Using `txt_color` for background and draw fg with white colortxt_color = (255, 255, 255)cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)def fromarray(self, im):"""Update self.im from a numpy array."""self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)self.draw = ImageDraw.Draw(self.im)def result(self):"""Return annotated image as array."""return np.asarray(self.im)def show(self, title=None):"""Show the annotated image."""Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)def save(self, filename="image.jpg"):"""Save the annotated image to 'filename'."""cv2.imwrite(filename, np.asarray(self.im))def get_bbox_dimension(self, bbox=None):"""Calculate the area of a bounding box.Args:bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).Returns:angle (degree): Degree value of angle between three points"""x_min, y_min, x_max, y_max = bboxwidth = x_max - x_minheight = y_max - y_minreturn width, height, width * heightdef draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):"""Draw region line.Args:reg_pts (list): Region Points (for line 2 points, for region 4 points)color (tuple): Region Color valuethickness (int): Region area thickness value"""cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)def draw_centroid_and_tracks(self, track, color=(255, 0, 255), 
track_thickness=2):"""Draw centroid point and track trails.Args:track (list): object tracking points for trails displaycolor (tuple): tracks line colortrack_thickness (int): track line thickness value"""points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):"""Displays queue counts on an image centered at the points with customizable font size and colors.Args:label (str): queue counts labelpoints (tuple): region points for center point calculation to display textregion_color (RGB): queue region colortxt_color (RGB): text display color"""x_values = [point[0] for point in points]y_values = [point[1] for point in points]center_x = sum(x_values) // len(points)center_y = sum(y_values) // len(points)text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]text_width = text_size[0]text_height = text_size[1]rect_width = text_width + 20rect_height = text_height + 20rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)text_x = center_x - text_width // 2text_y = center_y + text_height // 2# Draw textcv2.putText(self.im,label,(text_x, text_y),0,fontScale=self.sf,color=txt_color,thickness=self.tf,lineType=cv2.LINE_AA,)def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):"""Display the bounding boxes labels in parking management app.Args:im0 (ndarray): inference imagetext (str): object/class nametxt_color (bgr color): display color for text foregroundbg_color (bgr color): display color for text backgroundx_center (float): x position center point for bounding 
boxy_center (float): y position center point for bounding boxmargin (int): gap between text and rectangle for better display"""text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]text_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2rect_x1 = text_x - marginrect_y1 = text_y - text_size[1] - marginrect_x2 = text_x + text_size[0] + marginrect_y2 = text_y + margincv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)def display_analytics(self, im0, text, txt_color, bg_color, margin):"""Display the overall statistics for parking lots.Args:im0 (ndarray): inference imagetext (dict): labels dictionarytxt_color (bgr color): display color for text foregroundbg_color (bgr color): display color for text backgroundmargin (int): gap between text and rectangle for better display"""horizontal_gap = int(im0.shape[1] * 0.02)vertical_gap = int(im0.shape[0] * 0.01)text_y_offset = 0for label, value in text.items():txt = f"{label}: {value}"text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]if text_size[0] < 5 or text_size[1] < 5:text_size = (5, 5)text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gaptext_y = text_y_offset + text_size[1] + margin * 2 + vertical_gaprect_x1 = text_x - margin * 2rect_y1 = text_y - text_size[1] - margin * 2rect_x2 = text_x + text_size[0] + margin * 2rect_y2 = text_y + margin * 2cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)text_y_offset = rect_y2@staticmethoddef estimate_pose_angle(a, b, c):"""Calculate the pose angle for object.Args:a (float) : The value of pose point ab (float): The value of pose point bc (float): The value o pose point cReturns:angle (degree): Degree value of angle between three points"""a, b, c = np.array(a), np.array(b), np.array(c)radians = 
np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])angle = np.abs(radians * 180.0 / np.pi)if angle > 180.0:angle = 360 - anglereturn angledef draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):"""Draw specific keypoints for gym steps counting.Args:keypoints (list): Keypoints data to be plotted.indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].shape (tuple, optional): Image size for model inference. Defaults to (640, 640).radius (int, optional): Keypoint radius. Defaults to 2.conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.Returns:(numpy.ndarray): Image with drawn keypoints.Note:Keypoint format: [x, y] or [x, y, confidence].Modifies self.im in-place."""if indices is None:indices = [2, 5, 7]for i, k in enumerate(keypoints):if i in indices:x_coord, y_coord = k[0], k[1]if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:if len(k) == 3:conf = k[2]if conf < conf_thres:continuecv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)return self.imdef plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)):"""Plot the pose angle, count value and step stage.Args:angle_text (str): angle value for workout monitoringcount_text (str): counts value for workout monitoringstage_text (str): stage decision for workout monitoringcenter_kpt (list): centroid pose index for workout monitoringcolor (tuple): text background color for workout monitoringtxt_color (tuple): text foreground color for workout monitoring"""angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")# Draw angle(angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, self.sf, self.tf)angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))angle_background_position = (angle_text_position[0], 
angle_text_position[1] - angle_text_height - 5)angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (self.tf * 2))cv2.rectangle(self.im,angle_background_position,(angle_background_position[0] + angle_background_size[0],angle_background_position[1] + angle_background_size[1],),color,-1,)cv2.putText(self.im, angle_text, angle_text_position, 0, self.sf, txt_color, self.tf)# Draw Counts(count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, self.sf, self.tf)count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)count_background_position = (angle_background_position[0],angle_background_position[1] + angle_background_size[1] + 5,)count_background_size = (count_text_width + 10, count_text_height + 10 + self.tf)cv2.rectangle(self.im,count_background_position,(count_background_position[0] + count_background_size[0],count_background_position[1] + count_background_size[1],),color,-1,)cv2.putText(self.im, count_text, count_text_position, 0, self.sf, txt_color, self.tf)# Draw Stage(stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, self.sf, self.tf)stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40)stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)stage_background_size = (stage_text_width + 10, stage_text_height + 10)cv2.rectangle(self.im,stage_background_position,(stage_background_position[0] + stage_background_size[0],stage_background_position[1] + stage_background_size[1],),color,-1,)cv2.putText(self.im, stage_text, stage_text_position, 0, self.sf, txt_color, self.tf)def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):"""Function for drawing segmented object in bounding box shape.Args:mask (list): masks data list for instance segmentation area plottingmask_color (RGB): mask foreground colorlabel (str): Detection label 
texttxt_color (RGB): text color"""cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)cv2.rectangle(self.im,(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),(int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),mask_color,-1,)if label:cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf)def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):"""Plot the distance and line on frame.Args:distance_m (float): Distance between two bbox centroids in meters.distance_mm (float): Distance between two bbox centroids in millimeters.centroids (list): Bounding box centroids data.line_color (RGB): Distance line color.centroid_color (RGB): Bounding box centroid color."""(text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, self.sf, self.tf)cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), line_color, -1)cv2.putText(self.im,f"Distance M: {distance_m:.2f}m",(20, 50),0,self.sf,centroid_color,self.tf,cv2.LINE_AA,)(text_width_mm, text_height_mm), _ = cv2.getTextSize(f"Distance MM: {distance_mm:.2f}mm", 0, self.sf, self.tf)cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), line_color, -1)cv2.putText(self.im,f"Distance MM: {distance_mm:.2f}mm",(20, 100),0,self.sf,centroid_color,self.tf,cv2.LINE_AA,)cv2.line(self.im, centroids[0], centroids[1], line_color, 3)cv2.circle(self.im, centroids[0], 6, centroid_color, -1)cv2.circle(self.im, centroids[1], 6, centroid_color, -1)def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):"""Function for pinpoint human-vision eye mapping and plotting.Args:box (list): Bounding box coordinatescenter_point (tuple): center point for vision eye viewcolor (tuple): object centroid and line color 
valuepin_color (tuple): visioneye point color value"""center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)cv2.line(self.im, center_point, center_bbox, color, self.tf)@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):"""Plot training labels including class histograms and box statistics."""import pandas # scope for faster 'import ultralytics'import seaborn # scope for faster 'import ultralytics'# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarningswarnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")warnings.filterwarnings("ignore", category=FutureWarning)# Plot dataset labelsLOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")nc = int(cls.max() + 1) # number of classesboxes = boxes[:1000000] # limit to 1M boxesx = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])# Seaborn correlogramseaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)plt.close()# Matplotlib labelsax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)for i in range(nc):y[2].patches[i].set_color([x / 255 for x in colors(i)])ax[0].set_ylabel("instances")if 0 < len(names) < 30:ax[0].set_xticks(range(len(names)))ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)else:ax[0].set_xlabel("classes")seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)# Rectanglesboxes[:, 0:2] = 0.5 # centerboxes = ops.xywh2xyxy(boxes) * 1000img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)for cls, box in zip(cls[:500], boxes[:500]):ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plotax[1].imshow(img)ax[1].axis("off")for a in [0, 1, 2, 3]:for s in ["top", "right", "left", "bottom"]:ax[a].spines[s].set_visible(False)fname = save_dir / "labels.jpg"plt.savefig(fname, dpi=200)plt.close()if on_plot:on_plot(fname)def 
save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.This function takes a bounding box and an image, and then saves a cropped portion of the image accordingto the bounding box. Optionally, the crop can be squared, and the function allows for gain and paddingadjustments to the bounding box.Args:xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.im (numpy.ndarray): The input image.file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.save (bool, optional): If True, the cropped image will be saved to disk. 
Defaults to True.Returns:(numpy.ndarray): The cropped image.Example:```pythonfrom ultralytics.utils.plotting import save_one_boxxyxy = [50, 50, 150, 150]im = cv2.imread("image.jpg")cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)```"""if not isinstance(xyxy, torch.Tensor): # may be listxyxy = torch.stack(xyxy)b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxesif square:b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to squareb[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + padxyxy = ops.xywh2xyxy(b).long()xyxy = ops.clip_boxes(xyxy, im.shape)crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]if save:file.parent.mkdir(parents=True, exist_ok=True) # make directoryf = str(increment_path(file).with_suffix(".jpg"))# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issueImage.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGBreturn crop@threaded
def plot_images(images: Union[torch.Tensor, np.ndarray],batch_idx: Union[torch.Tensor, np.ndarray],cls: Union[torch.Tensor, np.ndarray],bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),confs: Optional[Union[torch.Tensor, np.ndarray]] = None,masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),paths: Optional[List[str]] = None,fname: str = "images.jpg",names: Optional[Dict[int, str]] = None,on_plot: Optional[Callable] = None,max_size: int = 1920,max_subplots: int = 16,save: bool = True,conf_thres: float = 0.25,
) -> Optional[np.ndarray]:"""Plot image grid with labels, bounding boxes, masks, and keypoints.Args:images: Batch of images to plot. Shape: (batch_size, channels, height, width).batch_idx: Batch indices for each detection. Shape: (num_detections,).cls: Class labels for each detection. Shape: (num_detections,).bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.confs: Confidence scores for each detection. Shape: (num_detections,).masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).kpts: Keypoints for each detection. Shape: (num_detections, 51).paths: List of file paths for each image in the batch.fname: Output filename for the plotted image grid.names: Dictionary mapping class indices to class names.on_plot: Optional callback function to be called after saving the plot.max_size: Maximum size of the output image grid.max_subplots: Maximum number of subplots in the image grid.save: Whether to save the plotted image grid to a file.conf_thres: Confidence threshold for displaying detections.Returns:np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.Note:This function supports both tensor and numpy array inputs. 
It will automaticallyconvert tensor inputs to numpy arrays for processing."""if isinstance(images, torch.Tensor):images = images.cpu().float().numpy()if isinstance(cls, torch.Tensor):cls = cls.cpu().numpy()if isinstance(bboxes, torch.Tensor):bboxes = bboxes.cpu().numpy()if isinstance(masks, torch.Tensor):masks = masks.cpu().numpy().astype(int)if isinstance(kpts, torch.Tensor):kpts = kpts.cpu().numpy()if isinstance(batch_idx, torch.Tensor):batch_idx = batch_idx.cpu().numpy()bs, _, h, w = images.shape # batch size, _, height, widthbs = min(bs, max_subplots) # limit plot imagesns = np.ceil(bs**0.5) # number of subplots (square)if np.max(images[0]) <= 1:images *= 255 # de-normalise (optional)# Build Imagemosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # initfor i in range(bs):x, y = int(w * (i // ns)), int(h * (i % ns)) # block originmosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)# Resize (optional)scale = max_size / ns / max(h, w)if scale < 1:h = math.ceil(scale * h)w = math.ceil(scale * w)mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))# Annotatefs = int((h + w) * ns * 0.01) # font sizeannotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)for i in range(bs):x, y = int(w * (i // ns)), int(h * (i % ns)) # block originannotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # bordersif paths:annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenamesif len(cls) > 0:idx = batch_idx == iclasses = cls[idx].astype("int")labels = confs is Noneif len(bboxes):boxes = bboxes[idx]conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)if len(boxes):if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1boxes[..., [0, 2]] *= w # scale to pixelsboxes[..., [1, 3]] *= helif scale < 1: # absolute coords need scale if image scalesboxes[..., :4] *= scaleboxes[..., 0] += xboxes[..., 1] += 
yis_obb = boxes.shape[-1] == 5 # xywhrboxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)for j, box in enumerate(boxes.astype(np.int64).tolist()):c = classes[j]color = colors(c)c = names.get(c, c) if names else cif labels or conf[j] > conf_thres:label = f"{c}" if labels else f"{c} {conf[j]:.1f}"annotator.box_label(box, label, color=color, rotated=is_obb)elif len(classes):for c in classes:color = colors(c)c = names.get(c, c) if names else cannotator.text((x, y), f"{c}", txt_color=color, box_style=True)# Plot keypointsif len(kpts):kpts_ = kpts[idx].copy()if len(kpts_):if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01kpts_[..., 0] *= w # scale to pixelskpts_[..., 1] *= helif scale < 1: # absolute coords need scale if image scaleskpts_ *= scalekpts_[..., 0] += xkpts_[..., 1] += yfor j in range(len(kpts_)):if labels or conf[j] > conf_thres:annotator.kpts(kpts_[j], conf_thres=conf_thres)# Plot masksif len(masks):if idx.shape[0] == masks.shape[0]: # overlap_masks=Falseimage_masks = masks[idx]else: # overlap_masks=Trueimage_masks = masks[[i]] # (1, 640, 640)nl = idx.sum()index = np.arange(nl).reshape((nl, 1, 1)) + 1image_masks = np.repeat(image_masks, nl, axis=0)image_masks = np.where(image_masks == index, 1.0, 0.0)im = np.asarray(annotator.im).copy()for j in range(len(image_masks)):if labels or conf[j] > conf_thres:color = colors(classes[j])mh, mw = image_masks[j].shapeif mh != h or mw != w:mask = image_masks[j].astype(np.uint8)mask = cv2.resize(mask, (w, h))mask = mask.astype(bool)else:mask = image_masks[j].astype(bool)with contextlib.suppress(Exception):im[y : y + h, x : x + w, :][mask] = (im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6)annotator.fromarray(im)if not save:return np.asarray(annotator.im)annotator.im.save(fname) # saveif on_plot:on_plot(fname)@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
    """
    Plot training results from a results CSV file. The function supports various types of data including segmentation,
    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.

    Every 'results*.csv' file found in that directory is drawn on the same grid; a file that fails to plot is logged
    and skipped.

    Args:
        file (str, optional): Path to the CSV file containing the training results. Only its parent directory is
            used. Defaults to 'path/to/results.csv'.
        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
            Defaults to None.

    Raises:
        AssertionError: If no 'results*.csv' file exists in the directory (assert-based, so skipped under -O).

    Example:
        ```python
        from ultralytics.utils.plotting import plot_results

        plot_results("path/to/results.csv", segment=True)
        ```
    """
    import pandas as pd  # scope for faster 'import ultralytics'
    from scipy.ndimage import gaussian_filter1d

    save_dir = Path(file).parent if file else Path(dir)
    # Check for input files before creating the figure, so a failed check does not leave an open figure behind.
    files = list(save_dir.glob("results*.csv"))
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."

    # Subplot grid and the CSV column plotted in each subplot, per task type
    if classify:
        fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
        index = [1, 4, 2, 3]
    elif segment:
        fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
        index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
    elif pose:
        fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
        index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
    else:
        fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
        index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
    ax = ax.ravel()
    for f in files:
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]  # first column is the x axis (epoch)
            for i, j in enumerate(index):
                y = data.values[:, j].astype("float")
                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
                ax[i].set_title(s[j], fontsize=12)
        except Exception as e:  # best effort: one malformed CSV must not abort the whole plot
            LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
    ax[1].legend()
    fname = save_dir / "results.png"
    fig.savefig(fname, dpi=200)
    plt.close(fig)
    if on_plot:
        on_plot(fname)


def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
    """
    Plots a scatter plot with points colored based on a 2D histogram.

    Each point is colored by the count of the 2D-histogram bin it falls into, so dense regions stand out.

    Args:
        v (array-like): Values for the x-axis.
        f (array-like): Values for the y-axis.
        bins (int, optional): Number of bins for the histogram. Defaults to 20.
        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.

    Examples:
        >>> v = np.random.rand(100)
        >>> f = np.random.rand(100)
        >>> plt_color_scatter(v, f)
    """
    # Calculate 2D histogram and corresponding colors
    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
    # Look up the bin np.histogram2d counted each point in: digitize(right=False) - 1 gives bin k for
    # edges[k] <= value < edges[k + 1], and clipping folds the closed right edge (the maximum) into the last bin.
    # (digitize(..., right=True) - 1 would return -1 for the minimum value, wrapping it to the *last* bin.)
    xi = np.clip(np.digitize(v, xedges) - 1, 0, hist.shape[0] - 1)
    yi = np.clip(np.digitize(f, yedges) - 1, 0, hist.shape[1] - 1)
    colors = hist[xi, yi]

    # Scatter plot
    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)


def plot_tune_results(csv_file="tune_results.csv"):
    """
    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.

    The first CSV column is read as fitness; every other column is treated as a hyperparameter. Two images are written
    next to the CSV: 'tune_scatter_plots.png' and 'tune_fitness.png'.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.

    Examples:
        >>> plot_tune_results("path/to/tune_results.csv")
    """
    import pandas as pd  # scope for faster 'import ultralytics'
    from scipy.ndimage import gaussian_filter1d

    def _save_one_file(file):
        """Save the current matplotlib figure to 'file' and close it."""
        plt.savefig(file, dpi=200)
        plt.close()
        LOGGER.info(f"Saved {file}")

    # Scatter plots for each hyperparameter
    csv_file = Path(csv_file)
    data = pd.read_csv(csv_file)
    num_metrics_columns = 1
    keys = [x.strip() for x in data.columns][num_metrics_columns:]
    x = data.values
    fitness = x[:, 0]  # fitness
    j = np.argmax(fitness)  # max fitness index
    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
    plt.figure(figsize=(10, 10), tight_layout=True)
    for i, k in enumerate(keys):
        v = x[:, i + num_metrics_columns]
        mu = v[j]  # best single result
        plt.subplot(n, n, i + 1)
        plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
        plt.plot(mu, fitness.max(), "k+", markersize=15)  # mark the best configuration
        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # hyperparameter value of the best configuration
        plt.tick_params(axis="both", labelsize=8)  # Set axis label size to 8
        if i % n != 0:
            plt.yticks([])  # y ticks only on the first column of the grid
    _save_one_file(csv_file.with_name("tune_scatter_plots.png"))

    # Fitness vs iteration
    x = range(1, len(fitness) + 1)
    plt.figure(figsize=(10, 6), tight_layout=True)
    plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    plt.grid(True)
    plt.legend()
    _save_one_file(csv_file.with_name("tune_fitness.png"))


def output_to_target(output, max_det=300):
    """
    Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.

    Args:
        output (list[torch.Tensor]): One tensor per image whose first six columns are (x1, y1, x2, y2, conf, cls).
        max_det (int, optional): Maximum number of rows kept per image. Defaults to 300.

    Returns:
        (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): batch ids, class ids, xywh boxes and confidences.
    """
    targets = []
    for i, o in enumerate(output):
        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i)  # image index column
        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
    targets = torch.cat(targets, 0).numpy()
    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]


def output_to_rotated_target(output, max_det=300):
    """
    Convert rotated model output to target format [batch_id, class_id, x, y, w, h, angle, conf] for plotting.

    Args:
        output (list[torch.Tensor]): One tensor per image with exactly seven columns (x, y, w, h, conf, cls, angle).
        max_det (int, optional): Maximum number of rows kept per image. Defaults to 300.

    Returns:
        (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): batch ids, class ids, xywhr boxes (N, 5) and
            confidences.
    """
    targets = []
    for i, o in enumerate(output):
        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i)  # image index column
        targets.append(torch.cat((j, cls, box, angle, conf), 1))
    targets = torch.cat(targets, 0).numpy()
    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]


def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
    """
    Visualize feature maps of a given model module during inference.

    Modules whose type contains a model-head name are skipped, as are feature maps with height or width of 1. For the
    first image in the batch, up to `n` channels are drawn in an 8-column grid and saved as a PNG; that image's full
    feature tensor is also saved next to it as a .npy file.

    Args:
        x (torch.Tensor): Features to be visualized, unpacked as (batch, channels, height, width).
        module_type (str): Module type.
        stage (int): Module stage within the model.
        n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
        save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
    """
    heads = {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}  # all model heads
    if any(m in module_type for m in heads):
        return
    if isinstance(x, torch.Tensor):
        _, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:
            f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
            n = min(n, channels)  # number of plots
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # ceil(n/8) rows x 8 cols
            ax = ax.ravel()
            fig.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis("off")

            LOGGER.info(f"Saving {f}... ({n}/{channels})")
            fig.savefig(f, dpi=300, bbox_inches="tight")
            plt.close(fig)
            np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # npy save