在之前的文章中记录了YOLO环境的配置安装和基本命令的一些使用,上一篇博文的地址快速链接:从零开始使用YOLOv8——环境配置与极简指令(CLI)操作:1篇文章解决—直接使用:模型部署 and 自建数据集:训练微调-CSDN博客
使用YOLO作为目标检测任务的平台一个好处是,其搭建了非常简洁明了的训练命令行模式,可以便捷的对自建数据集进行微调。
在对自己数据集进行模型训练前,非常重要费时的就是对数据的预处理,包括数据清洗、统计信息分析、数据格式转换。本文专注于将自己数据 json 格式转为YOLO训练支持的 txt 数据格式,并给出可以复用的数据集构建代码,代码已上传至Gitee平台。
Gitee链接:https://gitee.com/machine-bai-xue/yolo-source-code-analysis
如果链接失效,访问404拒绝,可以直接在Gitee码云主页搜索——“机器白学”,所有项目中的YOLO源码实验就是本系列所有实验代码。
目录
一、初始自建数据集Json格式
1.文件存放格式
2.标签JSON格式
二、YOLO训练数据集创建类
1.直接使用
2.可视化检查
3.完整代码与扩展
一、初始自建数据集Json格式
1.文件存放格式
首先约定一下数据的初始格式,本文选择最简单的 JSON 列表数据集格式作标签保存。所有图片在一个文件夹(img)下,所有便签在另一个文件夹(json)下。
2.标签JSON格式
对标签数据具体来说,坐标数据(下图红框)和类别数据(下图蓝框)存放在同一个列表里,前四个为左上右下两点的xy绝对坐标值,类别字符信息放在后面。
(如果还有除了类别外的其余信息可以直接在列表后添加,最后构建数据集只需取出对应索引即可。例如,对于同一批图片数据和坐标数据可能存在多种分类任务,就官方coco例子来说,其给的类别是详细分类后的物体名称——长颈鹿、花瓶、杯子......如果想构建一个大致分类的检测模型,如将细化类改为抽象类名称——动物、装饰品、日用品......只需在列表后面继续添加即可,如下第一个图第一个框可以改为——【385, 60, 600, 357, “giraffe”, “动物”】)
二、YOLO训练数据集创建类
将任意训练数据集搭建成初始的Json格式并存放按文件夹存放后,即可使用下面的数据转换类生成多个符合训练标准的数据集格式——包括训练集train和验证集val、yaml配置文件、txt标签文件等部分。
1.直接使用
首先总览整个类的使用。首先确保 opencv-python(cv2)和 pillow(PIL)库正确安装在环境里了。
导入定义的转化类(可在文章最后直接复制,或者在Gitee地址下载对应py文件),实例初始化。初始化中三个关于文件地址的基本参数是必须存在的。
初始化基本参数按输入顺序含义归纳在下面表格。
img_path | 所有图片存放文件地址(str) |
label_path | 所有Json格式标签数据文件地址(str) |
save_path | Yolo数据集结果保存地址(str) |
另外初始化中还有几个可以调整的附加参数。其影响数据集搭建的某些细节部分。
train_ratio | 训练集占总数据量的比重,小数数据格式(float) |
cls_id | 训练数据集中类别标签在原始数据中的索引位置(int,>=4) |
seed | 设置打乱数据集文件名的随机种子——随机分配训练和验证数据(int) |
配置完参数后,直接使用类下的 dataset_main() 方法就可以自动生成训练验证数据集和yaml配置文件了。
所有数据按照设定的划分比例随机采样分开。
其中 cls_freq.json 是统计所有类别出现的频率字典,字典键对应类别名,值对应出现的次数。可以根据其频率查看哪些类训练样本偏少,决定是否要进行数据增强操作。
.yaml 文件是YOLO训练的配置文件。其中names是按出现频率排序的类别和标签索引对。
2.可视化检查
转换类中还定义了一些可视化函数,可以检查数据是否正确。其中对于初始数据只需直接使用visual_json_main() 方法即可。
还可以直接使用类中的可视化函数,进行自定义的检查。
3.完整代码与扩展
下面将完整类代码放在下面,可以对其中相应函数方法进行修改实现扩展任务。欢迎批评指正。
import os
import json
import random
import cv2
import yaml
import numpy as np
from PIL import Image, ImageDraw, ImageFontclass YOLO_Dataset_Creator:def __init__(self, img_path, label_path, save_path, train_ratio=0.85, cls_id=4, seed=42):self.img_path, self.labels, self.data = img_path, label_path, save_pathself.train = train_ratioself.cls = Noneself.cls_id = cls_idself.seed = seeddef dataset_main(self):# 读取图片信息划分train和val集self.tr_name, self.val_name = self.divide_dataset(self.labels)# 根据json中类别信息生成配置yaml文件self.cls = self.generate_yaml(self.labels)# 生成yolo的txt训练数据集self.dataset_create(self.data)def divide_dataset(self, json_path):# 根据图片获取所有文件名信息total_file_list = []for file in os.listdir(json_path):if file.lower().split('.')[-1] in ['json']:base = file.split('.')[0]total_file_list.append(base)# 随机打乱后按比例生成训练train和验证valrandom.seed(self.seed)random.shuffle(total_file_list)length = len(total_file_list)tr_file_list = [tr for tr in total_file_list[:int(self.train * length)]]val_file_list = [te for te in total_file_list[int(self.train * length):]]return tr_file_list, val_file_listdef statis_info(self, json_path):cls_dict = dict()for file in os.listdir(json_path):if file.lower().split('.')[-1] in ['json']:jsondir = os.path.join(json_path, file)with open(jsondir, 'r', encoding='utf-8') as f:box_cls_list = json.load(f)for box_cls in box_cls_list:cls = box_cls[self.cls_id]if str(cls) not in cls_dict.keys():cls_dict[str(cls)] = 1else:cls_dict[str(cls)] +=1cls_dictdir = os.path.join(self.data, 'cls_freq.json')with open(cls_dictdir, 'w') as f:json.dump(cls_dict, f)return cls_dictdef generate_yaml(self, json_path):# 获得类别频率字典cls_dict = self.statis_info(json_path)# 生成yaml配置文件sorted_cls = sorted(cls_dict, key=cls_dict.get, reverse=True)names_dict = {}clses_dict = {}for id, c in enumerate(sorted_cls):names_dict[id] = cclses_dict[c] = idyaml_dict = {"path":self.data,"train":"images/train","val":"images/val","names":names_dict}yaml_savedir = os.path.join(self.data, 'HP_Data.yaml')with open(yaml_savedir, "w") as f:yaml.dump(yaml_dict,f)print('yaml success')return clses_dictdef dataset_create(self,data_path):# 必要子文件生成tr_img = os.path.join(data_path, 'images/train')va_img = os.path.join(data_path, 'images/val')tr_lab = os.path.join(data_path, 'labels/train')va_lab = os.path.join(data_path, 'labels/val')# 创建文件夹for file in [tr_img, va_img, tr_lab, va_lab]:os.makedirs(file, exist_ok=True)# train和valself._train(tr_img, tr_lab)self._val(va_img, va_lab)def _train(self, tr_img, tr_lab):# 生产训练集trainfor name in self.tr_name:print(name, 'start')# 图片复制保存移动imgsave = os.path.join(tr_img, name+'.jpg')jpgdir = os.path.join(self.img_path, name+'.jpg')# 标签txtsave = os.path.join(tr_lab, name+'.txt')jsondir = os.path.join(self.labels, name+'.json')self.label_create(jsondir, txtsave, name, jpgdir, imgsave)def _val(self, va_img, va_lab):# 生产验证集valfor name in self.val_name:# 图片复制保存移动imgsave = os.path.join(va_img, name + '.jpg')jpgdir = os.path.join(self.img_path, name + '.jpg')# 标签txtsave = os.path.join(va_lab, name + '.txt')jsondir = os.path.join(self.labels, name + '.json')self.label_create(jsondir, txtsave, name, jpgdir, imgsave)def label_create(self, jsondir, txtsave, name, jpgdir, imgsave):# 图片信息img = cv2.imread(jpgdir)if img is None:return print(name, 'jpg is empty')height, width, _ = img.shape# 框信息with open(jsondir, 'r', encoding='utf-8') as f:box_list = json.load(f)box_str_list = []for box_cls in box_list:box = box_cls[:4]conv_box = self.normalize((width, height), box)text = box_cls[self.cls_id]cls = self.cls[str(text)]conv_box.insert(0, int(cls))box_str = " ".join(str(item) for item in conv_box)+'\n'box_str_list.append(box_str)if box_str_list!=[]:with open(txtsave, "w", encoding="utf-8") as f:f.writelines(box_str_list)if os.path.exists(txtsave):cv2.imwrite(imgsave, img)def normalize(self, size, box): # size:(原图w,原图h) , box:(xmin,xmax,ymin,ymax)# 锚框归一化dw = 1. / size[0] # 1/wdh = 1. / size[1] # 1/hx = (box[0] + box[2]) / 2.0 # 物体在图中的中心点x坐标y = (box[1] + box[3]) / 2.0 # 物体在图中的中心点y坐标w = box[2] - box[0] # 物体实际像素宽度h = box[3] - box[1] # 物体实际像素高度x = x * dw # 物体中心点x的坐标比(相当于 x/原图w)w = w * dw # 物体宽度的宽度比(相当于 w/原图w)y = y * dh # 物体中心点y的坐标比(相当于 y/原图h)h = h * dh # 物体宽度的宽度比(相当于 h/原图h)return [x, y, w, h] # 返回 相对于原图的物体中心点的x坐标比,y坐标比,宽度比,高度比,取值范围[0-1]def denormalize(self, size, normalized_box): # size: (原图w, 原图h), normalized_box: [x, y, w, h]# 提取原图的宽度和高度w, h = size# 将中心点坐标和宽高还原为原图的像素坐标x_center = normalized_box[0] * wy_center = normalized_box[1] * hbox_width = normalized_box[2] * wbox_height = normalized_box[3] * h# 计算还原后的边界框坐标xmin = x_center - box_width / 2.0xmax = x_center + box_width / 2.0ymin = y_center - box_height / 2.0ymax = y_center + box_height / 2.0return [int(xmin), int(ymin), int(xmax), int(ymax)] # 返回还原后的边界框坐标def visual_json_main(self, visfile):for file in os.listdir(self.img_path):base = file.split('.')[0]jpgdir = os.path.join(self.img_path, file)img = cv2.imread(jpgdir)if img is None:return print('img is empty')jsondir = os.path.join(self.labels, base+'.json')with open(jsondir, 'r', encoding='utf-8') as f:boxes_list = json.load(f)box_list = []for box_ in boxes_list:box = box_[:4]text = box_[self.cls_id]box.append(text)box_list.append(box)visdir = os.path.join(visfile, file)vis = self.visual_word(box_list, img ,(0,255,0))cv2.imwrite(visdir, vis)print(file,' success')def visual_box(self, box_list, img, color):for idx, box in enumerate(box_list):if len(box) == 4:l, t, r, b = box# 图片画框cv2.rectangle(img, (int(l), int(t)), (int(r), int(b)), color, thickness=2, lineType=cv2.LINE_AA)elif len(box) == 8:pts = np.array(box, np.int32)pts = pts.reshape((-1, 1, 2))cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)return imgdef visual_word(self, word_list, img, color):font_path = "C:\Windows\Fonts\SimHei.ttf"font_size = 45for idx, box in enumerate(word_list):l, t, r, b, word = box# 图片画框cv2.rectangle(img, (l, t), (r, b), color, thickness=2, lineType=cv2.LINE_AA)caption = f"{word}"pil_img = Image.new('RGB', (0, 0))draw = ImageDraw.Draw(pil_img)text_size = draw.textbbox((0, 0), caption, font=ImageFont.truetype(font_path, font_size))text_width = text_size[2] - text_size[0]text_height = text_size[3] - text_size[1]# 使用 OpenCV 画文本背景框cv2.rectangle(img, (r, t), (r + text_width, t + text_height), color, -1)# 使用 PIL 绘制中文标签img = self.put_chinese_text(img, caption, (r, t), font_size, font_path, (0, 0, 0))return imgdef put_chinese_text(self, img, text, position, font_size, font_path, text_color):# 创建一个 PIL 图像pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))draw = ImageDraw.Draw(pil_img)# 加载字体font = ImageFont.truetype(font_path, font_size)# 绘制文本draw.text(position, text, font=font, fill=text_color)# 将 PIL 图像转换回 OpenCV 格式img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)return img