
✨Young people should live for the moment!💪


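  • 😺0. Repository source code
  • 😺1. Dataset introduction
    • 🐶1.1 GitHub raw dataset
    • 🐶1.2 GitHub preprocessed dataset
      • 🦄1.2.1 Simplified drawing files (.ndjson)
      • 🦄1.2.2 Binary files (.bin)
      • 🦄1.2.3 Numpy bitmaps (.npy)
    • 🐶1.3 Kaggle dataset
  • 😺2. Dataset preparation
  • 😺3. Generating PNG images
  • 😺4. Training process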
    • 🐶4.1 split_datasets.py
    • 🐶4.2 option.py
    • 🐶4.3 getdata.py
    • 🐶4.4 model.py
    • 🐶4.5 train-DDP.py
    • 🐶4.6 model_transfer.py
    • 🐶4.7 evaluate.py




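The Quick Draw dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The drawings were captured as timestamped vectors and tagged with metadata, including what the player was asked to draw and the country or region in which the player was located.

GitHub dataset: 📎The Quick, Draw! Dataset

Kaggle dataset: 📎Quick, Draw! Doodle Recognition Challenge

GitHub provides two types of datasets: the raw dataset and the preprocessed dataset.
Google Cloud hosts the download links for the dataset: quickdraw_dataset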

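🐶1.1 GitHub raw dataset

The raw data is provided as ndjson files separated by category, in the following format:

Key           Type     Description
countrycode   string   Two-letter country/region code (ISO 3166-1 alpha-2) of where the player was located
drawing       string   A JSON array representing the vector drawing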


  { "key_id":"5891796615823360","word":"nose","countrycode":"AE","timestamp":"2017-03-01 20:41:36.70725 UTC","recognized":true,"drawing":[[[129,128,129,129,130,130,131,132,132,133,133,133,133,...]]]}


[
  [  // First stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  [  // Second stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  ... // Additional strokes
]
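To make the nesting concrete, here is a minimal sketch that walks one record of this shape and reports the point count and duration of each stroke. The record is made up for illustration, `summarize_strokes` is a hypothetical helper, and the duration is only meaningful under the assumption that the `t` values are milliseconds.

```python
import json

# Hypothetical record in the raw layout: every stroke is [xs, ys, ts].
record_line = json.dumps({
    "key_id": "0",
    "word": "nose",
    "drawing": [
        [[10, 20, 30], [5, 15, 25], [0, 40, 80]],
        [[30, 35], [25, 40], [200, 260]],
    ],
})


def summarize_strokes(line):
    """Return (number of points, last t minus first t) for every stroke."""
    drawing = json.loads(line)["drawing"]
    summary = []
    for xs, ys, ts in drawing:
        summary.append((len(xs), ts[-1] - ts[0]))
    return summary


print(summarize_strokes(record_line))
```

The same loop works on a real file by calling the helper on each line of an ndjson file.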


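🐶1.2 GitHub preprocessed dataset

🦄1.2.1 Simplified drawing files (.ndjson)


  1. Align the drawing to the top-left corner so that the minimum values are 0.
  2. Uniformly scale the drawing so that the maximum value is 255.
  3. Resample all strokes with a 1-pixel spacing.
  4. Simplify all strokes using the Ramer-Douglas-Peucker algorithm with an epsilon value of 2.0.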
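The steps above can be sketched roughly as follows. This is an illustration, not the dataset's actual preprocessing code: it covers alignment, uniform scaling, and Ramer-Douglas-Peucker simplification, and it leaves out the 1-pixel resampling step. The function names `align_and_scale` and `rdp` are hypothetical.

```python
import math


def align_and_scale(strokes, max_value=255):
    """Shift the drawing to the top-left corner and scale it uniformly."""
    min_x = min(x for xs, _ in strokes for x in xs)
    min_y = min(y for _, ys in strokes for y in ys)
    max_x = max(x for xs, _ in strokes for x in xs)
    max_y = max(y for _, ys in strokes for y in ys)
    extent = max(max_x - min_x, max_y - min_y) or 1  # avoid division by zero
    scale = max_value / extent
    return [([(x - min_x) * scale for x in xs],
             [(y - min_y) * scale for y in ys]) for xs, ys in strokes]


def _point_line_distance(p, a, b):
    """Perpendicular distance from point p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)
    return abs(dy * px - dx * py + bx * ay - by * ax) / math.hypot(dx, dy)


def rdp(points, epsilon=2.0):
    """Ramer-Douglas-Peucker simplification of a list of (x, y) points."""
    if len(points) < 3:
        return list(points)
    distances = [_point_line_distance(p, points[0], points[-1]) for p in points[1:-1]]
    index = max(range(len(distances)), key=distances.__getitem__) + 1
    if distances[index - 1] <= epsilon:
        # Every interior point is close to the chord: keep only the endpoints
        return [points[0], points[-1]]
    left = rdp(points[:index + 1], epsilon)
    right = rdp(points[index:], epsilon)
    return left[:-1] + right  # drop the duplicated split point


print(align_and_scale([([10, 20], [30, 50])]))
print(rdp([(0, 0), (10, 1), (20, -1), (30, 0)]))
print(rdp([(0, 0), (10, 0), (10, 10)]))
```

Small jitter along a straight stroke collapses to its two endpoints, while a real corner survives because it lies farther than epsilon from the chord.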


# read_ndjson.py
import json

with open('aircraft carrier.ndjson', 'r') as file:
    for line in file:
        data = json.loads(line)
        key_id = data['key_id']
        drawing = data['drawing']
        # ……

Reading aircraft carrier.ndjson in a debugger produces the output shown in the figure below. The first line of data contains 8 strokes.

🦄1.2.2 Binary files (.bin)



# read_bin.py
import struct
from struct import unpack


def unpack_drawing(file_handle):
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for i in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image
    }


def unpack_drawings(filename):
    with open(filename, 'rb') as f:
        while True:
            try:
                yield unpack_drawing(f)
            except struct.error:
                break


for drawing in unpack_drawings('nose.bin'):
    # do something with the drawing
    print(drawing['country_code'])
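The field layout read by `unpack_drawing` can be checked with a small round trip. The sketch below writes one made-up drawing with the same `struct` format codes and reads the header back. `pack_drawing` and `record_size` are hypothetical helpers, and the values are invented for illustration.

```python
import io
import struct


def pack_drawing(key_id, country_code, recognized, timestamp, strokes):
    """Serialize one drawing using the field order that unpack_drawing reads."""
    chunks = [
        struct.pack('Q', key_id),         # 8-byte unsigned key id
        struct.pack('2s', country_code),  # 2-byte country code
        struct.pack('b', recognized),     # 1-byte recognized flag
        struct.pack('I', timestamp),      # 4-byte unsigned timestamp
        struct.pack('H', len(strokes)),   # 2-byte stroke count
    ]
    for xs, ys in strokes:
        chunks.append(struct.pack('H', len(xs)))        # points in this stroke
        chunks.append(struct.pack(f'{len(xs)}B', *xs))  # x coordinates, 0..255
        chunks.append(struct.pack(f'{len(ys)}B', *ys))  # y coordinates, 0..255
    return b''.join(chunks)


def record_size(strokes):
    """Bytes used by one record: 17-byte header plus 2 + 2 * n_points per stroke."""
    return 8 + 2 + 1 + 4 + 2 + sum(2 + 2 * len(xs) for xs, _ in strokes)


strokes = [((0, 255), (10, 20)), ((5,), (6,))]
blob = pack_drawing(5891796615823360, b'AE', 1, 1488400896, strokes)

handle = io.BytesIO(blob)
key_id, = struct.unpack('Q', handle.read(8))
country_code, = struct.unpack('2s', handle.read(2))
print(len(blob), record_size(strokes), key_id, country_code)
```

Because coordinates are stored as single unsigned bytes, this format only suits drawings already scaled into the 0 to 255 range.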

🦄1.2.3 Numpy bitmaps (.npy)



# read_npy.py
import numpy as np

data_path = 'aircraft_carrier.npy'
data = np.load(data_path)
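A loaded array can be turned back into images with a reshape. The sketch below assumes each row is a flattened 28x28 grayscale bitmap, which is how the dataset's README describes the .npy files. It uses a random stand-in array held in memory rather than a real category file.

```python
import io
import numpy as np

# Stand-in for a real category file: 5 drawings, each a flattened 28x28 uint8 bitmap.
fake = np.random.default_rng(0).integers(0, 256, size=(5, 784), dtype=np.uint8)
buffer = io.BytesIO()
np.save(buffer, fake)
buffer.seek(0)

data = np.load(buffer)                    # shape: (num_drawings, 784)
images = data.reshape(-1, 28, 28)         # one 28x28 image per drawing
scaled = images.astype(np.float32) / 255  # pixel values in [0, 1]
print(data.shape, images.shape, scaled.dtype)
```

With a real file, replace the buffer with the path passed to `np.load`.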

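🐶1.3 Kaggle dataset


  • sample_submission.csv - a sample submission file in the correct format
  • test_raw.csv - the test data in raw vector format
  • test_simplified.csv - the test data in simplified vector format
  • train_raw.zip - the training data in raw vector format; one CSV file per word
  • train_simplified.zip - the training data in simplified vector format; one CSV file per word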




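  1. Convert the CSV files of all classes to PNG images;
  2. Take 10,000 PNG images from each of the 340 classes for the rest of the walkthrough;
  3. Split the 10,000 images of each class into training, validation, and test sets at a ratio of 8:1:1;
  4. Train the model;
  5. Evaluate the model.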



# csv2png.py
import ast
import os

import matplotlib
matplotlib.use('Agg')   # Non-interactive backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

input_dir = 'kaggle/train_simplified'
output_base_dir = 'datasets256'
os.makedirs(output_base_dir, exist_ok=True)

csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]    # Retrieve all CSV files from the folder
skipped_files = []  # Record skipped files

for csv_file in csv_files:
    csv_file_path = os.path.join(input_dir, csv_file)   # Build the full file path
    output_dir = os.path.join(output_base_dir, os.path.splitext(csv_file)[0])   # Build the output directory
    if os.path.exists(output_dir):      # Skip classes whose output directory already exists
        skipped_files.append(csv_file)
        print(f'The directory already exists, skip file: {csv_file}')
        continue
    os.makedirs(output_dir, exist_ok=True)
    data = pd.read_csv(csv_file_path)       # Read the CSV file
    for index, row in data.iterrows():      # Traverse each row of data
        drawing = ast.literal_eval(row['drawing'])  # Parse the stroke list without running eval on file content
        key_id = row['key_id']
        word = row['word']
        fig = plt.figure(figsize=(256/96, 256/96), dpi=96)  # 256x256 pixels at 96 dpi
        for stroke in drawing:      # Draw each stroke
            x = np.array(stroke[0])
            y = np.array(stroke[1])
            plt.plot(x, y, 'k')
        ax = plt.gca()
        ax.xaxis.set_ticks_position('top')
        ax.invert_yaxis()   # Image coordinates: the origin is the top-left corner
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'{word}-{key_id}.png'))
        plt.close(fig)
        print(f'Conversion completed: {csv_file}, image {index:06d}')

print("The skipped files are:")
for file in skipped_files:
    print(file)



  • The GitHub preprocessed ndjson files total 23 GB
  • The Kaggle train_raw.zip file is 206 GB
  • The Kaggle train_simplified.zip file is 23 GB
  • Converting Kaggle train_simplified to 256*256 images takes 470 GB



# check_class_num.py
import os

folder = 'datasets256'
subfolders = [f.path for f in os.scandir(folder) if f.is_dir()]

for subfolder in subfolders:    # Traverse each subfolder
    folder_name = os.path.basename(subfolder)   # Name of the subfolder
    files = [f for f in os.scandir(subfolder) if f.is_file()]   # All files in the subfolder
    # Count the image files
    image_count = sum(1 for f in files if f.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')))
    if image_count == 0:        # No images: report the subfolder and delete it
        print(f"There are no images in the subfolder '{folder_name}', deleting it...")
        os.rmdir(subfolder)
        print(f"Subfolder '{folder_name}' deleted")
    else:
        print(f"Number of images in subfolder '{folder_name}': {image_count}")



🐶4.1 split_datasets.py


import os
import shutil
import random

original_dataset_path = 'datasets256'   # Original dataset path
new_dataset_path = 'datasets'           # Path of the split dataset

train_path = os.path.join(new_dataset_path, 'train')
val_path = os.path.join(new_dataset_path, 'val')
test_path = os.path.join(new_dataset_path, 'test')

if not os.path.exists(train_path):
    os.makedirs(train_path)
if not os.path.exists(val_path):
    os.makedirs(val_path)
if not os.path.exists(test_path):
    os.makedirs(test_path)

classes = os.listdir(original_dataset_path)     # Get all categories
random.seed(42)

for class_name in classes:      # Traverse each category
    src_folder = os.path.join(original_dataset_path, class_name)    # Source folder path
    # Folders for this category under train, val, and test
    train_folder = os.path.join(train_path, class_name)
    val_folder = os.path.join(val_path, class_name)
    test_folder = os.path.join(test_path, class_name)
    # If all three folders already exist and none is empty, skip this category
    if os.path.exists(train_folder) and os.path.exists(val_folder) and os.path.exists(test_folder):
        if os.listdir(train_folder) and os.listdir(val_folder) and os.listdir(test_folder):
            print(f"Category {class_name} already exists and is not empty, skip processing.")
            continue
    # Create the folders
    if not os.path.exists(train_folder):
        os.makedirs(train_folder)
    if not os.path.exists(val_folder):
        os.makedirs(val_folder)
    if not os.path.exists(test_folder):
        os.makedirs(test_folder)
    files = os.listdir(src_folder)      # All file names in this category
    files = files[:10000]               # Keep only the first 10000 files
    random.shuffle(files)               # Shuffle the file list
    total_files = len(files)
    train_split_index = int(total_files * 0.8)
    val_split_index = int(total_files * 0.9)
    train_files = files[:train_split_index]
    val_files = files[train_split_index:val_split_index]
    test_files = files[val_split_index:]
    for file in train_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(train_folder, file)
        shutil.copy(src_file, dst_file)
    for file in val_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(val_folder, file)
        shutil.copy(src_file, dst_file)
    for file in test_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(test_folder, file)
        shutil.copy(src_file, dst_file)

print("Dataset partitioning completed!")
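The index arithmetic behind the 8:1:1 split can be tried without touching the file system. The sketch below is an illustration only: `split_files` is a hypothetical helper and the file names are invented. It keeps the script's rule of taking the first 10,000 entries before shuffling.

```python
import random


def split_files(files, limit=10000, seed=42):
    """Return (train, val, test) lists using the same 0.8 / 0.9 cut points as the script."""
    rng = random.Random(seed)
    selected = list(files[:limit])  # same "first N files" rule as split_datasets.py
    rng.shuffle(selected)
    train_end = int(len(selected) * 0.8)
    val_end = int(len(selected) * 0.9)
    return selected[:train_end], selected[train_end:val_end], selected[val_end:]


names = [f'cat-{i}.png' for i in range(12345)]  # hypothetical file names
train, val, test = split_files(names)
print(len(train), len(val), len(test))
```

Taking the first 10,000 entries depends on the order returned by `os.listdir`, so shuffling before truncating would be the alternative if a random sample of the whole class is wanted.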


🐶4.2 option.py


import argparse


def str2bool(value):
    # argparse's type=bool treats every non-empty string (including "False") as True,
    # so boolean options are parsed explicitly
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes', 'y'):
        return True
    if value.lower() in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--num_classes', type=int, default=340, help='image num classes')
    parser.add_argument('--loadsize', type=int, default=64, help='image size')
    parser.add_argument('--epochs', type=int, default=100, help='all epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='init lr')
    parser.add_argument('--use_lr_scheduler', type=str2bool, default=True, help='use lr scheduler')
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='train path')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='val path')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='test path')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='ckpt path')
    parser.add_argument('--tensorboard_dir', type=str, default='./tensorboard_dir', help='log path')
    parser.add_argument('--resume', type=str2bool, default=False, help='continue training')
    parser.add_argument('--resume_ckpt', type=str, default='./checkpoints/model_best.pth', help='choose breakpoint ckpt')
    # Accept both spellings, since launchers differ in which one they pass
    parser.add_argument('--local-rank', '--local_rank', dest='local_rank', type=int, default=-1, help='local rank')
    parser.add_argument('--use_mix_precision', type=str2bool, default=False, help='use mixed precision')
    parser.add_argument('--test_img_path', type=str, default='datasets/test/zigzag/zigzag-4508464694951936.png', help='choose test image')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='choose test path')
    return parser.parse_args()


🐶4.3 getdata.py


import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from option import get_args

opt = get_args()
mean = [0.9367, 0.9404, 0.9405]
std = [0.1971, 0.1970, 0.1972]


def data_augmentation():
    data_transform = {
        'train': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }
    return data_transform


def MyData():
    data_transform = data_augmentation()
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_val, data_transform['val']),
    }
    data_sampler = {
        'train': torch.utils.data.distributed.DistributedSampler(image_datasets['train']),
        'val': torch.utils.data.distributed.DistributedSampler(image_datasets['val']),
    }
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['train']),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['val'])
    }
    return dataloaders


class_names = [
'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon',
'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 
'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag'
]


if __name__ == '__main__':
    mean_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_val, transform=mean_std_transform)
    print(dataset.class_to_idx)     # Index of each category
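Because both loaders use a DistributedSampler, each process only iterates over its own share of the dataset. The sketch below shows how that partitioning works when shuffling is turned off, as far as I understand the sampler's default padding behavior. `shard_indices` is a hypothetical helper, not a PyTorch function.

```python
import math


def shard_indices(dataset_size, num_replicas, rank):
    """Indices one rank would see without shuffling; padding wraps around to the start."""
    num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_size))
    indices += indices[:total_size - dataset_size]  # pad so every rank gets the same count
    return indices[rank:total_size:num_replicas]


for rank in range(2):
    print(rank, shard_indices(5, 2, rank))
```

One consequence for training code is that per-process counters, such as the number of correct predictions, cover only about `1 / num_replicas` of the dataset until they are summed across processes.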

🐶4.4 model.py


import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from torchsummary import summary
from option import get_args

opt = get_args()


def CustomMobileNetV3():
    model = mobilenet_v3_small(weights='MobileNet_V3_Small_Weights.IMAGENET1K_V1')
    # Replace the last classifier layer so that it outputs opt.num_classes logits
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, opt.num_classes)
    return model


if __name__ == '__main__':
    # option.py defines no device argument, so the device is chosen here
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CustomMobileNetV3()
    print(model)
    summary(model.to(device), (3, opt.loadsize, opt.loadsize), opt.batch_size, device=device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [1024, 16, 32, 32]             432
       BatchNorm2d-2         [1024, 16, 32, 32]              32
         Hardswish-3         [1024, 16, 32, 32]               0
            Conv2d-4         [1024, 16, 16, 16]             144
       BatchNorm2d-5         [1024, 16, 16, 16]              32
              ReLU-6         [1024, 16, 16, 16]               0
 AdaptiveAvgPool2d-7           [1024, 16, 1, 1]               0
            Conv2d-8            [1024, 8, 1, 1]             136
              ReLU-9            [1024, 8, 1, 1]               0
           Conv2d-10           [1024, 16, 1, 1]             144
      Hardsigmoid-11           [1024, 16, 1, 1]               0
SqueezeExcitation-12         [1024, 16, 16, 16]               0
           Conv2d-13         [1024, 16, 16, 16]             256
      BatchNorm2d-14         [1024, 16, 16, 16]              32
 InvertedResidual-15         [1024, 16, 16, 16]               0
           Conv2d-16         [1024, 72, 16, 16]           1,152
      BatchNorm2d-17         [1024, 72, 16, 16]             144
             ReLU-18         [1024, 72, 16, 16]               0
           Conv2d-19           [1024, 72, 8, 8]             648
      BatchNorm2d-20           [1024, 72, 8, 8]             144
             ReLU-21           [1024, 72, 8, 8]               0
           Conv2d-22           [1024, 24, 8, 8]           1,728
      BatchNorm2d-23           [1024, 24, 8, 8]              48
 InvertedResidual-24           [1024, 24, 8, 8]               0
           Conv2d-25           [1024, 88, 8, 8]           2,112
      BatchNorm2d-26           [1024, 88, 8, 8]             176
             ReLU-27           [1024, 88, 8, 8]               0
           Conv2d-28           [1024, 88, 8, 8]             792
      BatchNorm2d-29           [1024, 88, 8, 8]             176
             ReLU-30           [1024, 88, 8, 8]               0
           Conv2d-31           [1024, 24, 8, 8]           2,112
      BatchNorm2d-32           [1024, 24, 8, 8]              48
 InvertedResidual-33           [1024, 24, 8, 8]               0
           Conv2d-34           [1024, 96, 8, 8]           2,304
      BatchNorm2d-35           [1024, 96, 8, 8]             192
        Hardswish-36           [1024, 96, 8, 8]               0
           Conv2d-37           [1024, 96, 4, 4]           2,400
      BatchNorm2d-38           [1024, 96, 4, 4]             192
        Hardswish-39           [1024, 96, 4, 4]               0
AdaptiveAvgPool2d-40           [1024, 96, 1, 1]               0
           Conv2d-41           [1024, 24, 1, 1]           2,328
             ReLU-42           [1024, 24, 1, 1]               0
           Conv2d-43           [1024, 96, 1, 1]           2,400
      Hardsigmoid-44           [1024, 96, 1, 1]               0
SqueezeExcitation-45           [1024, 96, 4, 4]               0
           Conv2d-46           [1024, 40, 4, 4]           3,840
      BatchNorm2d-47           [1024, 40, 4, 4]              80
 InvertedResidual-48           [1024, 40, 4, 4]               0
           Conv2d-49          [1024, 240, 4, 4]           9,600
      BatchNorm2d-50          [1024, 240, 4, 4]             480
        Hardswish-51          [1024, 240, 4, 4]               0
           Conv2d-52          [1024, 240, 4, 4]           6,000
      BatchNorm2d-53          [1024, 240, 4, 4]             480
        Hardswish-54          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-55          [1024, 240, 1, 1]               0
           Conv2d-56           [1024, 64, 1, 1]          15,424
             ReLU-57           [1024, 64, 1, 1]               0
           Conv2d-58          [1024, 240, 1, 1]          15,600
      Hardsigmoid-59          [1024, 240, 1, 1]               0
SqueezeExcitation-60          [1024, 240, 4, 4]               0
           Conv2d-61           [1024, 40, 4, 4]           9,600
      BatchNorm2d-62           [1024, 40, 4, 4]              80
 InvertedResidual-63           [1024, 40, 4, 4]               0
           Conv2d-64          [1024, 240, 4, 4]           9,600
      BatchNorm2d-65          [1024, 240, 4, 4]             480
        Hardswish-66          [1024, 240, 4, 4]               0
           Conv2d-67          [1024, 240, 4, 4]           6,000
      BatchNorm2d-68          [1024, 240, 4, 4]             480
        Hardswish-69          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-70          [1024, 240, 1, 1]               0
           Conv2d-71           [1024, 64, 1, 1]          15,424
             ReLU-72           [1024, 64, 1, 1]               0
           Conv2d-73          [1024, 240, 1, 1]          15,600
      Hardsigmoid-74          [1024, 240, 1, 1]               0
SqueezeExcitation-75          [1024, 240, 4, 4]               0
           Conv2d-76           [1024, 40, 4, 4]           9,600
      BatchNorm2d-77           [1024, 40, 4, 4]              80
 InvertedResidual-78           [1024, 40, 4, 4]               0
           Conv2d-79          [1024, 120, 4, 4]           4,800
      BatchNorm2d-80          [1024, 120, 4, 4]             240
        Hardswish-81          [1024, 120, 4, 4]               0
           Conv2d-82          [1024, 120, 4, 4]           3,000
      BatchNorm2d-83          [1024, 120, 4, 4]             240
        Hardswish-84          [1024, 120, 4, 4]               0
AdaptiveAvgPool2d-85          [1024, 120, 1, 1]               0
           Conv2d-86           [1024, 32, 1, 1]           3,872
             ReLU-87           [1024, 32, 1, 1]               0
           Conv2d-88          [1024, 120, 1, 1]           3,960
      Hardsigmoid-89          [1024, 120, 1, 1]               0
SqueezeExcitation-90          [1024, 120, 4, 4]               0
           Conv2d-91           [1024, 48, 4, 4]           5,760
      BatchNorm2d-92           [1024, 48, 4, 4]              96
 InvertedResidual-93           [1024, 48, 4, 4]               0
           Conv2d-94          [1024, 144, 4, 4]           6,912
      BatchNorm2d-95          [1024, 144, 4, 4]             288
        Hardswish-96          [1024, 144, 4, 4]               0
           Conv2d-97          [1024, 144, 4, 4]           3,600
      BatchNorm2d-98          [1024, 144, 4, 4]             288
        Hardswish-99          [1024, 144, 4, 4]               0
AdaptiveAvgPool2d-100          [1024, 144, 1, 1]               0
          Conv2d-101           [1024, 40, 1, 1]           5,800
            ReLU-102           [1024, 40, 1, 1]               0
          Conv2d-103          [1024, 144, 1, 1]           5,904
     Hardsigmoid-104          [1024, 144, 1, 1]               0
SqueezeExcitation-105          [1024, 144, 4, 4]               0
          Conv2d-106           [1024, 48, 4, 4]           6,912
     BatchNorm2d-107           [1024, 48, 4, 4]              96
InvertedResidual-108           [1024, 48, 4, 4]               0
          Conv2d-109          [1024, 288, 4, 4]          13,824
     BatchNorm2d-110          [1024, 288, 4, 4]             576
       Hardswish-111          [1024, 288, 4, 4]               0
          Conv2d-112          [1024, 288, 2, 2]           7,200
     BatchNorm2d-113          [1024, 288, 2, 2]             576
       Hardswish-114          [1024, 288, 2, 2]               0
AdaptiveAvgPool2d-115          [1024, 288, 1, 1]               0
          Conv2d-116           [1024, 72, 1, 1]          20,808
            ReLU-117           [1024, 72, 1, 1]               0
          Conv2d-118          [1024, 288, 1, 1]          21,024
     Hardsigmoid-119          [1024, 288, 1, 1]               0
SqueezeExcitation-120          [1024, 288, 2, 2]               0
          Conv2d-121           [1024, 96, 2, 2]          27,648
     BatchNorm2d-122           [1024, 96, 2, 2]             192
InvertedResidual-123           [1024, 96, 2, 2]               0
          Conv2d-124          [1024, 576, 2, 2]          55,296
     BatchNorm2d-125          [1024, 576, 2, 2]           1,152
       Hardswish-126          [1024, 576, 2, 2]               0
          Conv2d-127          [1024, 576, 2, 2]          14,400
     BatchNorm2d-128          [1024, 576, 2, 2]           1,152
       Hardswish-129          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-130          [1024, 576, 1, 1]               0
          Conv2d-131          [1024, 144, 1, 1]          83,088
            ReLU-132          [1024, 144, 1, 1]               0
          Conv2d-133          [1024, 576, 1, 1]          83,520
     Hardsigmoid-134          [1024, 576, 1, 1]               0
SqueezeExcitation-135          [1024, 576, 2, 2]               0
          Conv2d-136           [1024, 96, 2, 2]          55,296
     BatchNorm2d-137           [1024, 96, 2, 2]             192
InvertedResidual-138           [1024, 96, 2, 2]               0
          Conv2d-139          [1024, 576, 2, 2]          55,296
     BatchNorm2d-140          [1024, 576, 2, 2]           1,152
       Hardswish-141          [1024, 576, 2, 2]               0
          Conv2d-142          [1024, 576, 2, 2]          14,400
     BatchNorm2d-143          [1024, 576, 2, 2]           1,152
       Hardswish-144          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-145          [1024, 576, 1, 1]               0
          Conv2d-146          [1024, 144, 1, 1]          83,088
            ReLU-147          [1024, 144, 1, 1]               0
          Conv2d-148          [1024, 576, 1, 1]          83,520
     Hardsigmoid-149          [1024, 576, 1, 1]               0
SqueezeExcitation-150          [1024, 576, 2, 2]               0
          Conv2d-151           [1024, 96, 2, 2]          55,296
     BatchNorm2d-152           [1024, 96, 2, 2]             192
InvertedResidual-153           [1024, 96, 2, 2]               0
          Conv2d-154          [1024, 576, 2, 2]          55,296
     BatchNorm2d-155          [1024, 576, 2, 2]           1,152
       Hardswish-156          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-157          [1024, 576, 1, 1]               0
          Linear-158               [1024, 1024]         590,848
       Hardswish-159               [1024, 1024]               0
         Dropout-160               [1024, 1024]               0
          Linear-161                [1024, 340]         348,500
================================================================
Total params: 1,866,356
Trainable params: 1,866,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 2979.22
Params size (MB): 7.12
Estimated Total Size (MB): 3034.34
----------------------------------------------------------------
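Several figures in this summary can be reproduced by hand. The sketch below recomputes the parameter counts of the two Linear layers and the reported sizes, assuming 4 bytes per float32 value. `linear_params` and `megabytes` are hypothetical helper names.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def megabytes(num_values, bytes_per_value=4):
    """Size in MB (1024 ** 2 bytes) of num_values float32 numbers."""
    return num_values * bytes_per_value / 1024 ** 2


print(linear_params(576, 1024))                 # Linear-158: 576 pooled features -> 1024
print(linear_params(1024, 340))                 # Linear-161: 1024 -> 340 classes
print(round(megabytes(1_866_356), 2))           # Params size (MB)
print(round(megabytes(1024 * 3 * 64 * 64), 2))  # Input size (MB) for a batch of 1024 images
```

The last Linear layer is the one replaced in model.py, so its parameter count depends directly on `--num_classes`.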

🐶4.5 train-DDP.py


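  • DDP distributed training (single machine, two GPUs);
  • AMP mixed-precision training;
  • Learning-rate decay;
  • Early stopping;
  • Resuming training from a checkpoint.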
# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="" --master_port=12345 train-DDP.py --use_mix_precision True
# Watch the training log: tensorboard --logdir=tensorboard_dir
from tqdm import tqdm
import torch
import torch.nn.parallel
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import time
import os
import torch.optim
import torch.utils.data
import torch.nn as nn
from model import CustomMobileNetV3
from getdata import MyData
from torch.cuda.amp import GradScaler
from option import get_args

opt = get_args()
dist.init_process_group(backend='nccl', init_method='env://')
os.makedirs(opt.checkpoints, exist_ok=True)


def train(gpu):
    rank = dist.get_rank()
    model = CustomMobileNetV3()
    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    scaler = GradScaler(enabled=opt.use_mix_precision)
    dataloaders = MyData()
    train_loader = dataloaders['train']
    test_loader = dataloaders['val']
    scheduler = None
    if opt.use_lr_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    start_time = time.time()
    best_val_acc = 0.0
    no_improve_epochs = 0
    early_stopping_patience = 6  # Early stopping patience

    # Resume from a checkpoint
    if opt.resume:
        checkpoint = torch.load(opt.resume_ckpt, map_location=f'cuda:{gpu}')
        print('Loading checkpoint from:', opt.resume_ckpt)
        # The checkpoint stores model.module.state_dict(), so its keys carry no
        # 'module.' prefix and are loaded into the wrapped module directly
        model.module.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']   # Set the starting epoch
        if scheduler is not None and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
    else:
        start_epoch = 0

    for epoch in range(start_epoch + 1, opt.epochs):
        model.train()                           # valid() leaves the model in eval mode
        train_loader.sampler.set_epoch(epoch)   # Reshuffle the DistributedSampler every epoch
        tqdm_trainloader = tqdm(train_loader, desc=f'Epoch {epoch}')
        running_loss, running_correct_top1, running_correct_top3, running_correct_top5 = 0.0, 0.0, 0.0, 0.0
        total_samples = 0
        for i, (images, target) in enumerate(tqdm_trainloader if rank == 0 else train_loader, 0):
            images = images.to(gpu)
            target = target.to(gpu)
            with torch.cuda.amp.autocast(enabled=opt.use_mix_precision):
                output = model(images)
                loss = criterion(output, target)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(output.data, 1)
            running_correct_top1 += (predicted == target).sum().item()
            _, predicted_top3 = torch.topk(output.data, 3, dim=1)
            _, predicted_top5 = torch.topk(output.data, 5, dim=1)
            running_correct_top3 += (predicted_top3[:, :3] == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            running_correct_top5 += (predicted_top5[:, :5] == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            total_samples += target.size(0)

        state = {
            'epoch': epoch,
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        if scheduler is not None:
            state['scheduler'] = scheduler.state_dict()

        if rank == 0:
            current_lr = scheduler.get_last_lr()[0] if scheduler is not None else opt.lr
            # Each rank only sees its own shard, so normalize by the samples it processed
            print(f'[Epoch {epoch}]  '
                  f'[Train Loss: {running_loss / total_samples:.6f}]  '
                  f'[Train Top-1 Acc: {running_correct_top1 / total_samples:.6f}]  '
                  f'[Train Top-3 Acc: {running_correct_top3 / total_samples:.6f}]  '
                  f'[Train Top-5 Acc: {running_correct_top5 / total_samples:.6f}]  '
                  f'[Learning Rate: {current_lr:.6f}]  '
                  f'[Time: {time.time() - start_time:.6f} seconds]')
            writer.add_scalar('Train/Loss', running_loss / total_samples, epoch)
            writer.add_scalar('Train/Top-1 Accuracy', running_correct_top1 / total_samples, epoch)
            writer.add_scalar('Train/Top-3 Accuracy', running_correct_top3 / total_samples, epoch)
            writer.add_scalar('Train/Top-5 Accuracy', running_correct_top5 / total_samples, epoch)
            writer.add_scalar('Train/Learning Rate', current_lr, epoch)
            torch.save(state, f'{opt.checkpoints}model_epoch_{epoch}.pth')
        # dist.barrier()
        tqdm_trainloader.close()
        if scheduler is not None:   # Learning-rate scheduler
            scheduler.step()

        # Every rank receives the same accuracy, so early stopping stays in sync
        acc_top1 = valid(test_loader, model, epoch, gpu, rank)
        if acc_top1 > best_val_acc:
            best_val_acc = acc_top1
            no_improve_epochs = 0
            if rank == 0:
                torch.save(state, f'{opt.checkpoints}/model_best.pth')
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= early_stopping_patience:
                if rank == 0:
                    print(f'Early stopping triggered after {early_stopping_patience} epochs without improvement.')
                break
    dist.destroy_process_group()


def valid(val_loader, model, epoch, gpu, rank):
    model.eval()
    correct_top1 = torch.tensor(0.).to(gpu)
    correct_top3 = torch.tensor(0.).to(gpu)
    correct_top5 = torch.tensor(0.).to(gpu)
    total = torch.tensor(0.).to(gpu)
    with torch.no_grad():
        tqdm_valloader = tqdm(val_loader, desc=f'Epoch {epoch}')
        for i, (images, target) in enumerate(tqdm_valloader, 0):
            images = images.to(gpu)
            target = target.to(gpu)
            output = model(images)
            total += target.size(0)
            correct_top1 += (output.argmax(1) == target).type(torch.float).sum()
            _, predicted_top3 = torch.topk(output, 3, dim=1)
            _, predicted_top5 = torch.topk(output, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()
        tqdm_valloader.close()
    # Sum the counters across processes; all_reduce (rather than reduce to rank 0)
    # gives every rank the result, which the early-stopping check needs
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    dist.all_reduce(correct_top1, op=dist.ReduceOp.SUM)
    dist.all_reduce(correct_top3, op=dist.ReduceOp.SUM)
    dist.all_reduce(correct_top5, op=dist.ReduceOp.SUM)
    acc_top1 = float(correct_top1 / total)
    acc_top3 = float(correct_top3 / total)
    acc_top5 = float(correct_top5 / total)
    if rank == 0:
        print(f'[Epoch {epoch}]  '
              f'[Val Top-1 Acc: {acc_top1:.6f}]  '
              f'[Val Top-3 Acc: {acc_top3:.6f}]  '
              f'[Val Top-5 Acc: {acc_top5:.6f}]')
        writer.add_scalar('Validation/Top-1 Accuracy', acc_top1, epoch)
        writer.add_scalar('Validation/Top-3 Accuracy', acc_top3, epoch)
        writer.add_scalar('Validation/Top-5 Accuracy', acc_top5, epoch)
    return acc_top1     # Return the top-1 accuracy


def main():
    train(opt.local_rank)


if __name__ == '__main__':
    writer = SummaryWriter(log_dir=opt.tensorboard_dir)
    main()
    writer.close()
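Two of the mechanisms in this script, the StepLR decay and the early-stopping counter, are easy to check in isolation. The sketch below is a stand-alone illustration in plain Python rather than the script's own code: `step_lr` and `EarlyStopping` are hypothetical names, and the accuracy history is made up. The defaults mirror the script's `step_size=10`, `gamma=0.8`, and patience of 6.

```python
def step_lr(initial_lr, epoch, step_size=10, gamma=0.8):
    """Learning rate after `epoch` scheduler steps under a StepLR-style decay."""
    return initial_lr * gamma ** (epoch // step_size)


class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=6):
        self.patience = patience
        self.best = 0.0
        self.bad_epochs = 0

    def update(self, accuracy):
        if accuracy > self.best:
            self.best = accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=6)
history = [0.50, 0.60, 0.59, 0.60, 0.58, 0.57, 0.60, 0.59]  # made-up validation accuracies
stopped_at = None
for epoch, acc in enumerate(history, start=1):
    if stopper.update(acc):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
print([round(step_lr(0.001, e), 6) for e in (0, 9, 10, 25)])
```

Note that matching the previous best does not reset the counter, because the comparison is strict.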


python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="" --master_port=12345 train-DDP.py --use_mix_precision True


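  • nproc_per_node: number of GPUs (processes) per machine
  • nnodes: number of machines
  • node_rank: index of this machine
  • master_addr: IP address of the master machine
  • master_port: port of the master machine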
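How these parameters combine into process ranks can be sketched with a little arithmetic. `global_rank` and `world_size` below are hypothetical helpers. The environment variable names are the ones I understand the PyTorch launcher to set for `init_method='env://'`, so treat that last line as an assumption to verify in your environment.

```python
import os


def global_rank(node_rank, nproc_per_node, local_rank):
    """Rank of a process across all machines."""
    return node_rank * nproc_per_node + local_rank


def world_size(nnodes, nproc_per_node):
    """Total number of processes taking part in training."""
    return nnodes * nproc_per_node


# With the command above: one machine, two GPUs.
for local_rank in range(2):
    print(local_rank, global_rank(0, 2, local_rank), world_size(1, 2))

# Inside a launched worker these values are normally exposed as environment variables;
# outside of one, the lookups simply return None.
names = ('RANK', 'LOCAL_RANK', 'WORLD_SIZE', 'MASTER_ADDR', 'MASTER_PORT')
print({name: os.environ.get(name) for name in names})
```

On a single machine the global rank equals the local rank, which is why the script can use the local rank as the GPU index and rank 0 for logging.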


W0914 18:33:15.081479 140031432897728 torch/distributed/elastic/agent/server/api.py:741] Received Signals.SIGHUP death signal, shutting down workers
W0914 18:33:15.085310 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685186 closing signal SIGHUP
W0914 18:33:15.085644 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685192 closing signal SIGHUP

For the specific cause, see the thread on the official PyTorch discussion forum: DDP Error: torch.distributed.elastic.agent.server.api:Received 1 death signal, shutting down workers


  1. Install tmux: sudo apt-get install tmux
  2. Create a new session: tmux new -s train-DDP (the session name is up to you)
  3. Activate the virtual environment: conda activate pytorch (use whichever environment you actually need)
  4. Start the training job: python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="" --master_port=12345 train-DDP.py --use_mix_precision True
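Steps 2 and 4 can also be combined into one detached launch, using tmux's `new-session -d -s <name> <command>` form. The sketch below only builds and prints that command; `tmux_launch_command` and the `RUN` switch are hypothetical names, and it assumes the right Python environment is already active in the shell that starts tmux.

```python
import shlex
import shutil
import subprocess


def tmux_launch_command(session, command):
    """Build the argv that starts `command` in a new detached tmux session."""
    return ['tmux', 'new-session', '-d', '-s', session, command]


train_cmd = (
    'python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 '
    '--node_rank=0 --master_addr="" --master_port=12345 '
    'train-DDP.py --use_mix_precision True'
)
argv = tmux_launch_command('train-DDP', train_cmd)
print(shlex.join(argv))

# Only start the session when explicitly enabled and tmux is installed.
RUN = False
if RUN and shutil.which('tmux'):
    subprocess.run(argv, check=True)
```

Because the session is detached, closing the SSH terminal no longer sends SIGHUP to the training processes.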


  • List all current tmux sessions: tmux ls
  • Create a new session: tmux new -s <session-name>
  • Re-attach to a session: tmux attach -t <session-name>
  • Kill a session: tmux kill-session -t <session-name>


🐶4.6 model_transfer.py


from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
from model import CustomMobileNetV3
import onnx
from onnxsim import simplify
from option import get_args

opt = get_args()

model = CustomMobileNetV3()
model.load_state_dict(torch.load(f'{opt.checkpoints}model_best.pth', map_location='cpu')['model'])
model.eval()  # switch BatchNorm/Dropout to inference mode before saving, tracing and exporting
print("Model loaded successfully.")

"""Save .pth format model"""
torch.save(model, f'{opt.checkpoints}/model.pth')

"""Save .ptl format model"""
example = torch.rand(1, 3, 64, 64)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(f'{opt.checkpoints}model.ptl')

"""Save .onnx format model"""
input_name = ['input']
output_name = ['output']
dummy_input = torch.randn(1, 3, opt.loadsize, opt.loadsize)
torch.onnx.export(model, dummy_input, f'{opt.checkpoints}model.onnx', input_names=input_name, output_names=output_name, verbose=True)
# Run shape inference and save the model with the inferred shapes
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(f'{opt.checkpoints}model.onnx')), f'{opt.checkpoints}model.onnx')
# Simplify the model
model_onnx = onnx.load(f'{opt.checkpoints}model.onnx')
model_simplified, check = simplify(model_onnx)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simplified, f'{opt.checkpoints}model_simplified.onnx')

🐶4.7 evaluate.py


  • evaluate_image_single: predict a single image
  • evaluate_image_dir: predict all images in a folder
  • evaluate_onnx_model: predict an image with the ONNX model

The code also produces several visualizations and evaluation metrics, including a confusion matrix and the F1 score.
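As a rough illustration of how the macro-averaged metrics relate to the confusion matrix, here is a pure-Python sketch. It is not the scikit-learn implementation; `macro_metrics` is a hypothetical helper and the 3-class matrix is made up.

```python
def macro_metrics(cm):
    """Return (accuracy, macro precision, macro recall, macro F1).

    cm[i][j] is the number of samples of true class i predicted as class j.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    accuracy = sum(cm[c][c] for c in range(n)) / total
    precisions, recalls, f1s = [], [], []
    for c in range(n):
        tp = cm[c][c]
        predicted = sum(cm[r][c] for r in range(n))  # column sum: predicted as c
        actual = sum(cm[c])                          # row sum: truly class c
        p = tp / predicted if predicted else 0.0
        r = tp / actual if actual else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
    return accuracy, sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# Toy confusion matrix with 10 samples per true class (illustrative only).
cm = [
    [9, 1, 0],
    [2, 6, 2],
    [0, 1, 9],
]
acc, prec, rec, f1 = macro_metrics(cm)
print(f'Accuracy: {acc:.4f}  Precision: {prec:.4f}  Recall: {rec:.4f}  F1: {f1:.4f}')
```

When every class has the same number of test samples, as in this toy matrix, macro recall equals overall accuracy, which is consistent with the identical Accuracy and Recall values in the results reported for this model.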

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch.nn.functional as F
import torch.utils.data
import onnxruntime
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from tqdm import tqdm
from getdata import mean, std, class_names
from option import get_args
opt = get_args()
device = 'cuda:1'  # change to a GPU (or 'cpu') available on your machine


"""Predicting a single image"""
def evaluate_image_single(img_path, transform_test, model, class_names, top_k):
    model.eval()  # make sure Dropout/BatchNorm run in inference mode
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image).to(device)
    img = img.unsqueeze_(0)
    with torch.no_grad():
        out = model(img)
    pred_softmax = F.softmax(out, dim=1)
    top_n, top_n_indices = torch.topk(pred_softmax, top_k)
    confs = top_n[0].cpu().detach().numpy().tolist()
    class_names_top = [class_names[i] for i in top_n_indices[0]]
    for i in range(top_k):
        print(f'Pre: {class_names_top[i]}   Conf: {confs[i]:.3f}')

    confs_max = confs[0]
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(image)

    sorted_pairs = sorted(zip(class_names_top, confs), key=lambda x: x[1], reverse=True)
    sorted_class_names_top, sorted_confs = zip(*sorted_pairs)
    plt.subplot(1, 2, 2)
    bars = plt.bar(sorted_class_names_top, sorted_confs, color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title(f'Top {top_k} Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, sorted_confs):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')


"""Predicting folder images"""
def evaluate_image_dir(model, dataloader, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    correct_top1 = torch.tensor(0.).to(device)
    correct_top3 = torch.tensor(0.).to(device)
    correct_top5 = torch.tensor(0.).to(device)
    total = torch.tensor(0.).to(device)
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total += labels.size(0)
            correct_top1 += (outputs.argmax(1) == labels).type(torch.float).sum()
            _, predicted_top3 = torch.topk(outputs, 3, dim=1)
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == labels.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == labels.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            _, preds = torch.max(outputs, 1)
            # Collect plain Python ints instead of per-sample GPU tensors
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    y_pred = np.array(all_preds)
    y_true = np.array(all_labels)

    top1 = correct_top1 / total
    top3 = correct_top3 / total
    top5 = correct_top5 / total
    print(f"Top-1 Accuracy: {top1:.4f}")
    print(f"Top-3 Accuracy: {top3:.4f}")
    print(f"Top-5 Accuracy: {top5:.4f}")

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    cm = confusion_matrix(y_true, y_pred)
    report = classification_report(y_true, y_pred, target_names=class_names)
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(report)

    plt.figure(figsize=(100, 100))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 8})
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.jpg')


"""Using .onnx model to predict images"""
def evaluate_onnx_model(img_path, data_transform, onnx_model_path, class_names, top_k=5):
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    img_pil = Image.open(img_path).convert('RGB')
    input_img = data_transform(img_pil)
    input_tensor = input_img.unsqueeze(0).numpy()
    ort_inputs = {'input': input_tensor}
    out = ort_session.run(['output'], ort_inputs)[0]

    def softmax(x):
        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)

    prob_dist = softmax(out)
    result_dict = {label: float(prob_dist[0][i]) for i, label in enumerate(class_names)}
    result_dict = dict(sorted(result_dict.items(), key=lambda item: item[1], reverse=True))
    for key, value in list(result_dict.items())[:top_k]:
        print(f'Pre: {key}   Conf: {value:.3f}')

    confs_max = list(result_dict.values())[0]
    class_names_top = list(result_dict.keys())
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(img_pil)

    plt.subplot(1, 2, 2)
    bars = plt.bar(class_names_top[:top_k], list(result_dict.values())[:top_k], color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title(f'Top {top_k} Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, list(result_dict.values())[:top_k]):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')


if __name__ == '__main__':
    data_transform = transforms.Compose([transforms.Resize((opt.loadsize, opt.loadsize)), transforms.ToTensor(), transforms.Normalize(mean, std)])
    image_datasets = ImageFolder(opt.dataset_test, data_transform)
    dataloaders = DataLoader(image_datasets, batch_size=512, shuffle=True)

    ptl_model_path = opt.checkpoints + 'model.ptl'
    pth_model_path = opt.checkpoints + 'model.pth'
    onnx_model_path = opt.checkpoints + 'model.onnx'
    ptl_model = torch.jit.load(ptl_model_path).to(device)
    # model.pth stores the whole pickled model; recent PyTorch releases may need weights_only=False here
    pth_model = torch.load(pth_model_path).to(device)

    evaluate_image_single(opt.test_img_path, data_transform, pth_model, class_names, top_k=5)     # Predicting a single image
    # evaluate_image_dir(pth_model, dataloaders, class_names)     # Predicting folder images
    # evaluate_onnx_model(opt.test_img_path, data_transform, onnx_model_path, class_names, top_k=5)   # Predicting a single image
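`evaluate_onnx_model` converts the raw ONNX outputs to probabilities with a hand-written softmax. A common refinement, offered here as a suggestion rather than as part of the original script, is to subtract the maximum logit before exponentiating: the result is mathematically unchanged, but very large logits no longer overflow. A pure-Python sketch (`stable_softmax` is a hypothetical name):

```python
import math


def stable_softmax(logits):
    """Softmax over one row of logits, shifted by the row maximum for stability."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    norm = sum(exps)
    return [e / norm for e in exps]


# Shifting does not change the result: both rows differ only by a constant.
small = stable_softmax([0.0, 1.0, 2.0])
large = stable_softmax([1000.0, 1001.0, 1002.0])  # math.exp(1000.0) alone would overflow
print([round(p, 4) for p in small])
print([round(p, 4) for p in large])
```

In the NumPy version the same idea would be `np.exp(x - x.max(axis=1, keepdims=True))` followed by the usual row-wise normalization.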


Top-1 Accuracy: 0.6833
Top-3 Accuracy: 0.8521
Top-5 Accuracy: 0.8933
Accuracy: 0.6833
Precision: 0.6875
Recall: 0.6833
F1 Score: 0.6817
                         precision    recall  f1-score   support

The Eiffel Tower              0.83      0.88      0.85      1000
The Great Wall of China       0.47      0.36      0.41      1000
The Mona Lisa                 0.68      0.86      0.76      1000
airplane                      0.83      0.74      0.78      1000
alarm clock                   0.76      0.76      0.76      1000
ambulance                     0.70      0.65      0.67      1000
angel                         0.87      0.78      0.82      1000
animal migration              0.47      0.66      0.55      1000
ant                           0.77      0.74      0.75      1000
anvil                         0.80      0.66      0.72      1000
apple                         0.82      0.85      0.83      1000
arm                           0.74      0.69      0.71      1000
asparagus                     0.54      0.44      0.48      1000
axe                           0.69      0.67      0.68      1000
backpack                      0.61      0.75      0.67      1000
banana                        0.68      0.72      0.70      1000
bandage                       0.83      0.71      0.77      1000
barn                          0.66      0.68      0.67      1000
baseball                      0.77      0.71      0.74      1000
baseball bat                  0.75      0.73      0.74      1000
basket                        0.71      0.62      0.66      1000
basketball                    0.62      0.72      0.66      1000
bat                           0.79      0.62      0.69      1000
bathtub                       0.60      0.64      0.62      1000
beach                         0.58      0.65      0.61      1000
bear                          0.46      0.31      0.37      1000
beard                         0.56      0.73      0.63      1000
bed                           0.80      0.67      0.73      1000
bee                           0.82      0.74      0.78      1000
belt                          0.78      0.55      0.64      1000
bench                         0.59      0.53      0.56      1000
bicycle                       0.73      0.72      0.72      1000
binoculars                    0.74      0.77      0.76      1000
bird                          0.47      0.43      0.45      1000
birthday cake                 0.52      0.64      0.57      1000
blackberry                    0.46      0.42      0.44      1000
blueberry                     0.58      0.47      0.52      1000
book                          0.72      0.78      0.75      1000
boomerang                     0.73      0.70      0.71      1000
bottlecap                     0.58      0.54      0.56      1000
bowtie                        0.87      0.86      0.86      1000
bracelet                      0.68      0.60      0.64      1000
brain                         0.59      0.60      0.59      1000
bread                         0.54      0.63      0.58      1000
bridge                        0.61      0.64      0.63      1000
broccoli                      0.58      0.70      0.64      1000
broom                         0.56      0.68      0.61      1000
bucket                        0.62      0.66      0.64      1000
bulldozer                     0.69      0.70      0.70      1000
bus                           0.56      0.42      0.48      1000
bush                          0.47      0.65      0.55      1000
butterfly                     0.86      0.88      0.87      1000
cactus                        0.69      0.87      0.77      1000
cake                          0.53      0.42      0.47      1000
calculator                    0.76      0.82      0.79      1000
calendar                      0.54      0.50      0.52      1000
camel                         0.82      0.84      0.83      1000
camera                        0.87      0.74      0.80      1000
camouflage                    0.23      0.43      0.30      1000
campfire                      0.72      0.77      0.75      1000
candle                        0.75      0.73      0.74      1000
cannon                        0.77      0.69      0.72      1000
canoe                         0.67      0.63      0.65      1000
car                           0.65      0.63      0.64      1000
carrot                        0.75      0.82      0.78      1000
castle                        0.79      0.72      0.75      1000
cat                           0.69      0.66      0.68      1000
ceiling fan                   0.83      0.64      0.72      1000
cell phone                    0.62      0.60      0.61      1000
cello                         0.51      0.67      0.58      1000
chair                         0.83      0.80      0.81      1000
chandelier                    0.74      0.71      0.73      1000
church                        0.72      0.67      0.69      1000
circle                        0.53      0.86      0.66      1000
clarinet                      0.53      0.63      0.58      1000
clock                         0.86      0.77      0.82      1000
cloud                         0.73      0.69      0.71      1000
coffee cup                    0.67      0.43      0.52      1000
compass                       0.69      0.78      0.73      1000
computer                      0.79      0.62      0.69      1000
cookie                        0.68      0.80      0.74      1000
cooler                        0.47      0.33      0.38      1000
couch                         0.76      0.82      0.79      1000
cow                           0.70      0.57      0.63      1000
crab                          0.70      0.72      0.71      1000
crayon                        0.44      0.52      0.47      1000
crocodile                     0.65      0.57      0.60      1000
crown                         0.87      0.87      0.87      1000
cruise ship                   0.76      0.69      0.73      1000
cup                           0.43      0.50      0.47      1000
diamond                       0.73      0.88      0.80      1000
dishwasher                    0.56      0.47      0.51      1000
diving board                  0.53      0.54      0.53      1000
dog                           0.50      0.41      0.45      1000
dolphin                       0.79      0.59      0.68      1000
donut                         0.75      0.88      0.81      1000
door                          0.69      0.72      0.70      1000
dragon                        0.52      0.42      0.47      1000
dresser                       0.75      0.65      0.70      1000
drill                         0.78      0.71      0.75      1000
drums                         0.71      0.68      0.70      1000
duck                          0.68      0.49      0.57      1000
dumbbell                      0.78      0.80      0.79      1000
ear                           0.81      0.75      0.78      1000
elbow                         0.74      0.62      0.68      1000
elephant                      0.66      0.66      0.66      1000
envelope                      0.87      0.94      0.90      1000
eraser                        0.50      0.61      0.55      1000
eye                           0.83      0.85      0.84      1000
eyeglasses                    0.84      0.80      0.82      1000
face                          0.62      0.64      0.63      1000
fan                           0.76      0.60      0.67      1000
feather                       0.58      0.60      0.59      1000
fence                         0.67      0.71      0.69      1000
finger                        0.70      0.63      0.67      1000
fire hydrant                  0.56      0.64      0.60      1000
fireplace                     0.74      0.67      0.71      1000
firetruck                     0.71      0.50      0.59      1000
fish                          0.89      0.85      0.87      1000
flamingo                      0.69      0.75      0.72      1000
flashlight                    0.80      0.82      0.81      1000
flip flops                    0.64      0.75      0.69      1000
floor lamp                    0.77      0.70      0.74      1000
flower                        0.79      0.83      0.81      1000
flying saucer                 0.65      0.64      0.64      1000
foot                          0.68      0.66      0.67      1000
fork                          0.81      0.79      0.80      1000
frog                          0.46      0.47      0.47      1000
frying pan                    0.78      0.76      0.77      1000
garden                        0.59      0.63      0.61      1000
garden hose                   0.42      0.28      0.33      1000
giraffe                       0.87      0.80      0.84      1000
goatee                        0.72      0.73      0.72      1000
golf club                     0.60      0.62      0.61      1000
grapes                        0.68      0.65      0.66      1000
grass                         0.59      0.83      0.69      1000
guitar                        0.68      0.50      0.58      1000
hamburger                     0.66      0.83      0.73      1000
hammer                        0.71      0.75      0.73      1000
hand                          0.83      0.83      0.83      1000
harp                          0.83      0.78      0.80      1000
hat                           0.72      0.71      0.72      1000
headphones                    0.92      0.91      0.92      1000
hedgehog                      0.73      0.74      0.73      1000
helicopter                    0.81      0.83      0.82      1000
helmet                        0.63      0.66      0.64      1000
hexagon                       0.70      0.73      0.72      1000
hockey puck                   0.59      0.61      0.60      1000
hockey stick                  0.59      0.54      0.56      1000
horse                         0.53      0.85      0.65      1000
hospital                      0.80      0.68      0.74      1000
hot air balloon               0.79      0.72      0.75      1000
hot dog                       0.60      0.63      0.62      1000
hot tub                       0.58      0.51      0.54      1000
hourglass                     0.86      0.87      0.87      1000
house                         0.77      0.77      0.77      1000
house plant                   0.85      0.82      0.83      1000
hurricane                     0.39      0.45      0.42      1000
ice cream                     0.82      0.85      0.84      1000
jacket                        0.75      0.72      0.74      1000
jail                          0.71      0.72      0.71      1000
kangaroo                      0.73      0.71      0.72      1000
key                           0.71      0.76      0.74      1000
keyboard                      0.50      0.48      0.49      1000
knee                          0.63      0.68      0.65      1000
ladder                        0.88      0.91      0.89      1000
lantern                       0.70      0.53      0.60      1000
laptop                        0.63      0.80      0.71      1000
leaf                          0.73      0.71      0.72      1000
leg                           0.58      0.50      0.54      1000
light bulb                    0.69      0.79      0.73      1000
lighthouse                    0.71      0.74      0.72      1000
lightning                     0.76      0.69      0.72      1000
line                          0.55      0.82      0.66      1000
lion                          0.70      0.76      0.73      1000
lipstick                      0.59      0.69      0.63      1000
lobster                       0.61      0.47      0.53      1000
lollipop                      0.76      0.85      0.80      1000
mailbox                       0.75      0.66      0.70      1000
map                           0.65      0.73      0.68      1000
marker                        0.39      0.16      0.23      1000
matches                       0.52      0.47      0.49      1000
megaphone                     0.80      0.70      0.75      1000
mermaid                       0.76      0.84      0.80      1000
microphone                    0.64      0.73      0.68      1000
microwave                     0.79      0.75      0.77      1000
monkey                        0.59      0.56      0.57      1000
moon                          0.69      0.60      0.64      1000
mosquito                      0.48      0.55      0.51      1000
motorbike                     0.64      0.62      0.63      1000
mountain                      0.74      0.80      0.77      1000
mouse                         0.53      0.46      0.49      1000
moustache                     0.75      0.72      0.73      1000
mouth                         0.72      0.76      0.74      1000
mug                           0.54      0.65      0.59      1000
mushroom                      0.66      0.76      0.70      1000
nail                          0.58      0.66      0.62      1000
necklace                      0.75      0.63      0.68      1000
nose                          0.69      0.75      0.72      1000
ocean                         0.54      0.54      0.54      1000
octagon                       0.71      0.62      0.66      1000
octopus                       0.89      0.83      0.86      1000
onion                         0.75      0.68      0.71      1000
oven                          0.50      0.39      0.44      1000
owl                           0.68      0.65      0.67      1000
paint can                     0.51      0.49      0.50      1000
paintbrush                    0.58      0.63      0.61      1000
palm tree                     0.73      0.83      0.78      1000
panda                         0.66      0.62      0.64      1000
pants                         0.75      0.68      0.71      1000
paper clip                    0.75      0.78      0.76      1000
parachute                     0.81      0.79      0.80      1000
parrot                        0.54      0.59      0.56      1000
passport                      0.60      0.55      0.58      1000
peanut                        0.70      0.73      0.71      1000
pear                          0.72      0.80      0.76      1000
peas                          0.70      0.56      0.62      1000
pencil                        0.58      0.60      0.59      1000
penguin                       0.69      0.78      0.73      1000
piano                         0.65      0.66      0.65      1000
pickup truck                  0.60      0.64      0.62      1000
picture frame                 0.68      0.89      0.77      1000
pig                           0.77      0.56      0.65      1000
pillow                        0.60      0.58      0.59      1000
pineapple                     0.80      0.85      0.82      1000
pizza                         0.65      0.77      0.70      1000
pliers                        0.69      0.55      0.61      1000
police car                    0.67      0.68      0.67      1000
pond                          0.40      0.47      0.43      1000
pool                          0.51      0.23      0.32      1000
popsicle                      0.70      0.79      0.75      1000
postcard                      0.74      0.58      0.65      1000
potato                        0.54      0.40      0.46      1000
power outlet                  0.61      0.72      0.66      1000
purse                         0.64      0.69      0.66      1000
rabbit                        0.66      0.80      0.72      1000
raccoon                       0.43      0.44      0.44      1000
radio                         0.71      0.59      0.64      1000
rain                          0.77      0.90      0.83      1000
rainbow                       0.79      0.92      0.85      1000
rake                          0.69      0.67      0.68      1000
remote control                0.67      0.68      0.67      1000
rhinoceros                    0.65      0.75      0.69      1000
river                         0.66      0.61      0.64      1000
roller coaster                0.70      0.52      0.60      1000
rollerskates                  0.86      0.83      0.84      1000
sailboat                      0.84      0.87      0.86      1000
sandwich                      0.50      0.68      0.57      1000
saw                           0.81      0.83      0.82      1000
saxophone                     0.79      0.77      0.78      1000
school bus                    0.51      0.44      0.47      1000
scissors                      0.80      0.84      0.82      1000
scorpion                      0.70      0.76      0.73      1000
screwdriver                   0.58      0.62      0.60      1000
sea turtle                    0.79      0.73      0.76      1000
see saw                       0.85      0.79      0.82      1000
shark                         0.72      0.72      0.72      1000
sheep                         0.75      0.80      0.77      1000
shoe                          0.73      0.75      0.74      1000
shorts                        0.67      0.76      0.71      1000
shovel                        0.62      0.73      0.67      1000
sink                          0.62      0.76      0.68      1000
skateboard                    0.83      0.85      0.84      1000
skull                         0.86      0.83      0.85      1000
skyscraper                    0.65      0.56      0.60      1000
sleeping bag                  0.55      0.59      0.57      1000
smiley face                   0.74      0.80      0.77      1000
snail                         0.79      0.90      0.84      1000
snake                         0.65      0.66      0.65      1000
snorkel                       0.79      0.73      0.76      1000
snowflake                     0.79      0.84      0.81      1000
snowman                       0.83      0.90      0.86      1000
soccer ball                   0.69      0.70      0.69      1000
sock                          0.77      0.75      0.76      1000
speedboat                     0.65      0.65      0.65      1000
spider                        0.72      0.79      0.76      1000
spoon                         0.69      0.57      0.63      1000
spreadsheet                   0.67      0.62      0.65      1000
square                        0.52      0.84      0.65      1000
squiggle                      0.41      0.40      0.40      1000
squirrel                      0.71      0.74      0.72      1000
stairs                        0.90      0.91      0.90      1000
star                          0.93      0.91      0.92      1000
steak                         0.53      0.46      0.49      1000
stereo                        0.61      0.68      0.64      1000
stethoscope                   0.87      0.75      0.81      1000
stitches                      0.71      0.79      0.75      1000
stop sign                     0.86      0.88      0.87      1000
stove                         0.71      0.66      0.69      1000
strawberry                    0.80      0.80      0.80      1000
streetlight                   0.75      0.71      0.73      1000
string bean                   0.51      0.39      0.44      1000
submarine                     0.83      0.67      0.74      1000
suitcase                      0.75      0.57      0.64      1000
sun                           0.87      0.88      0.87      1000
swan                          0.69      0.67      0.68      1000
sweater                       0.68      0.65      0.67      1000
swing set                     0.89      0.90      0.89      1000
sword                         0.85      0.81      0.83      1000
t-shirt                       0.80      0.78      0.79      1000
table                         0.73      0.76      0.74      1000
teapot                        0.82      0.77      0.80      1000
teddy-bear                    0.66      0.74      0.70      1000
telephone                     0.67      0.54      0.60      1000
television                    0.88      0.85      0.86      1000
tennis racquet                0.86      0.74      0.80      1000
tent                          0.80      0.77      0.78      1000
tiger                         0.53      0.47      0.50      1000
toaster                       0.59      0.70      0.64      1000
toe                           0.67      0.63      0.65      1000
toilet                        0.74      0.80      0.77      1000
tooth                         0.72      0.74      0.73      1000
toothbrush                    0.74      0.76      0.75      1000
toothpaste                    0.54      0.56      0.55      1000
tornado                       0.63      0.69      0.66      1000
tractor                       0.65      0.71      0.68      1000
traffic light                 0.84      0.84      0.84      1000
train                         0.61      0.74      0.67      1000
tree                          0.72      0.75      0.73      1000
triangle                      0.87      0.93      0.90      1000
trombone                      0.58      0.48      0.53      1000
truck                         0.50      0.41      0.45      1000
trumpet                       0.65      0.49      0.56      1000
umbrella                      0.91      0.86      0.88      1000
underwear                     0.83      0.64      0.72      1000
van                           0.46      0.58      0.51      1000
vase                          0.82      0.67      0.74      1000
violin                        0.52      0.52      0.52      1000
washing machine               0.74      0.78      0.76      1000
watermelon                    0.56      0.66      0.61      1000
waterslide                    0.57      0.70      0.63      1000
whale                         0.71      0.74      0.72      1000
wheel                         0.82      0.50      0.62      1000
windmill                      0.82      0.77      0.79      1000
wine bottle                   0.77      0.81      0.79      1000
wine glass                    0.86      0.85      0.86      1000
wristwatch                    0.72      0.74      0.73      1000
yoga                          0.60      0.57      0.58      1000
zebra                         0.73      0.66      0.69      1000
zigzag                        0.73      0.75      0.74      1000

accuracy                                          0.68    340000
macro avg                     0.69      0.68      0.68    340000
weighted avg                  0.69      0.68      0.68    340000




