一般源码的测试代码涉及很多文件,因项目需要,这里写了一个可以独立运行的测试代码。传入的图片是字典形式(键为图片序列号,值为图像数据)。
import time
import cv2
import os
import numpy as np
import torch
from modules.detec.models.common import DetectMultiBackend
from modules.detec.utils.dataloaders import LoadImages
from modules.detec.utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr,increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from modules.detec.utils.augmentations import letterbox
from modules.detec.utils.plots import Annotator, colors


class DetectionEstimation:
    """Standalone batch object detector built on a ``DetectMultiBackend`` model.

    Images are passed in as a dict mapping a serial (any hashable identifier)
    to an image array, and results are returned keyed by the same serials.
    """

    def __init__(self, model_path, conf_threshold=0.9, iou_threshold=0.45, img_size=(384, 640)):
        """Load the model and remember the inference settings.

        Args:
            model_path: Weights path handed to ``DetectMultiBackend``.
            conf_threshold: Confidence threshold passed to non-max suppression.
            iou_threshold: IoU threshold passed to non-max suppression.
            img_size: Target shape every image is letterboxed to
                (presumably ``(height, width)``, as ``letterbox`` expects).
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = DetectMultiBackend(model_path).to(self.device)
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.img_size = img_size

    def _preprocess_image(self, img_dict):
        """Turn a dict of images into one normalized batch tensor.

        Args:
            img_dict: Mapping of serial -> HWC image array (the script below
                feeds arrays loaded with ``cv2.imread``).

        Returns:
            A tuple ``(batch, original_sizes)`` where ``batch`` is the
            concatenation of all preprocessed images along dim 0 and
            ``original_sizes`` maps each serial to ``img.shape[:2]``.
        """
        tensors = []
        original_sizes = {}
        for serial, img in img_dict.items():
            original_sizes[serial] = img.shape[:2]
            # auto=False pads every image to exactly self.img_size. With
            # auto=True the padding is a minimal stride-aligned rectangle that
            # depends on the aspect ratio, so images of different shapes could
            # not be concatenated into one batch.
            resized = letterbox(img, self.img_size, stride=32, auto=False)[0]
            # HWC -> CHW, then reverse the channel order.
            resized = resized.transpose((2, 0, 1))[::-1]
            resized = np.ascontiguousarray(resized)
            tensor = torch.from_numpy(resized).float().to(self.device) / 255
            if tensor.ndim == 3:
                tensor = tensor[None]  # add the batch dimension
            tensors.append(tensor)
        return torch.cat(tensors), original_sizes

    def _postprocess_predictions(self, predictions, original_sizes, input_shape=None):
        """Rescale detections to the original images and convert them to plain Python.

        Args:
            predictions: Per-image detections as returned by
                ``non_max_suppression``, in the same order as ``original_sizes``.
            original_sizes: Mapping of serial -> original ``shape[:2]``.
            input_shape: Spatial shape of the tensor the model actually saw.
                Defaults to ``self.img_size``.

        Returns:
            Dict mapping serial -> ``{'labels': [(name, conf), ...],
            'coordinates': [[x1, y1, x2, y2], ...]}``. Images without any
            detection have no entry.
        """
        if input_shape is None:
            input_shape = self.img_size
        results = {}
        for det, (serial, original_size) in zip(predictions, original_sizes.items()):
            if det is None or not len(det):
                continue
            # Map boxes from the letterboxed input back to the original image.
            det[:, :4] = scale_boxes(input_shape, det[:, :4], original_size).round()
            labels = []
            coordinates = []
            for *xyxy, conf, cls in reversed(det):
                labels.append((self.model.names[int(cls)], conf.item()))
                coordinates.append([value.item() for value in xyxy])
            results[serial] = {'labels': labels, 'coordinates': coordinates}
        return results

    def predict(self, img_dict):
        """Run the full detection pipeline on a dict of images.

        Prints the duration of each stage and the final results.

        Args:
            img_dict: Mapping of serial -> image array.

        Returns:
            The dict produced by ``_postprocess_predictions``; an empty dict
            when ``img_dict`` is empty.
        """
        if not img_dict:
            # torch.cat cannot build a batch from zero images.
            return {}

        start_total = time.perf_counter()

        start_preprocess = time.perf_counter()
        img_tensor, original_sizes = self._preprocess_image(img_dict)
        preprocess_time = time.perf_counter() - start_preprocess
        print(f"Preprocess Time: {preprocess_time * 1000:.2f}ms")

        start_inference = time.perf_counter()
        with torch.no_grad():
            predictions = self.model(img_tensor)
        inference_time = time.perf_counter() - start_inference
        print(f"Inference Time:{inference_time * 1000:.2f}ms")

        start_non_max_suppression = time.perf_counter()
        predictions = non_max_suppression(predictions, self.conf_threshold, self.iou_threshold)
        non_max_suppression_time = time.perf_counter() - start_non_max_suppression
        print(f"Non-Max Suppression Time: {non_max_suppression_time * 1000:.2f}ms")

        start_postprocess = time.perf_counter()
        # Pass the real tensor shape so box rescaling matches the padding used.
        results = self._postprocess_predictions(
            predictions, original_sizes, input_shape=tuple(img_tensor.shape[2:]))
        postprocess_time = time.perf_counter() - start_postprocess
        print(f"Postprocess Time: {postprocess_time * 1000:.2f}ms")

        total_time = time.perf_counter() - start_total
        print(f"Total Processing Time: {total_time * 1000:.2f}ms")
        print("res:", results)
        return results

    def draw_results(self, img_dict, results):
        """Draw the detected boxes and labels onto the images.

        Args:
            img_dict: Mapping of serial -> image array.
            results: Output of ``predict``.

        Returns:
            Dict mapping serial -> annotated image, only for serials that
            appear in ``results``.
        """
        # NOTE(review): Annotator may draw directly on the array it is given,
        # in which case the images in img_dict are modified too - confirm.
        annotated_images = {}
        for serial, img in img_dict.items():
            if serial not in results:
                continue
            boxes = results[serial]['coordinates']  # rescaled box coordinates
            labels = results[serial]['labels']  # (label, confidence) pairs
            annotator = Annotator(img, line_width=3, example=self.model.names)
            for i, (xyxy, (label, conf)) in enumerate(zip(boxes, labels)):
                label_str = f'{label} {conf:.2f}'
                annotator.box_label(xyxy, label_str, color=colors(i, True))
            annotated_images[serial] = annotator.result()
        return annotated_images

    def _save_labels(self, results, output_folder, batch_size=3):
        """Write detection results to text files, ``batch_size`` images per file.

        Each file is named after the serials it contains, joined by ``_`` and
        suffixed with ``_labels.txt``.

        Args:
            results: Output of ``predict`` (serials must be strings).
            output_folder: Directory to write into; created if missing.
            batch_size: Number of images whose results share one file.
        """
        os.makedirs(output_folder, exist_ok=True)
        img_serials = list(results.keys())
        for i in range(0, len(img_serials), batch_size):
            batch = img_serials[i:i + batch_size]
            combined_filename = '_'.join(batch) + '_labels.txt'
            labels_path = os.path.join(output_folder, combined_filename)
            with open(labels_path, 'w', encoding='utf-8') as file:
                for serial in batch:
                    result = results[serial]
                    file.write("{\n")
                    # The result dict has no 'serial' key; the serial is the
                    # key it is stored under.
                    file.write(f" 'serial': '{serial}',\n")
                    file.write(f" 'labels': {result['labels']},\n")
                    file.write(f" 'coordinates': {result['coordinates']},\n")
                    file.write("}\n\n")


if __name__ == "__main__":
    model_path = 'data/pt/best.pt'
    detector = DetectionEstimation(model_path)

    # Load every supported image in the folder, keyed by its file stem.
    img_folder = './data/images/'
    img_dict = {}
    for img_filename in os.listdir(img_folder):
        img_path = os.path.join(img_folder, img_filename)
        if not img_path.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_data = cv2.imread(img_path)
        if img_data is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f'Error reading image {img_path}')
            continue
        serial = os.path.splitext(img_filename)[0]
        img_dict[serial] = img_data

    batch_size = 2
    img_keys = list(img_dict.keys())
    os.makedirs('results', exist_ok=True)
    for i in range(0, len(img_keys), batch_size):
        batch_dict = {k: img_dict[k] for k in img_keys[i:i + batch_size]}
        results = detector.predict(batch_dict)
        annotated_images = detector.draw_results(batch_dict, results)
        for serial, img in annotated_images.items():
            output_path = f'results/{serial}.jpg'
            success = cv2.imwrite(output_path, img)
            if not success:
                print(f'Error saving image {output_path}')
            else:
                print(f'Successfully saved image {output_path}')
        detector._save_labels(results, 'results/labels', batch_size=batch_size)
在该代码的同级目录下放置 models、results、utils 文件夹和 export.py。
运行该代码得到的txt文件是字典: