Deep Learning Camp - Week J3-1: Breast Cancer Recognition with DenseNet

  • 🍨 This post is a learning-log entry for the 🔗 365-Day Deep Learning Training Camp (365天深度学习训练营)
  • 🍖 Original author: K同学啊

My environment

  • Language: Python 3.12
  • Editor: Jupyter Lab
  • Deep learning stack: PyTorch 2.4.1, Torchvision 0.19.1
  • Dataset: breast cancer image dataset

I. Preparation

This week we use the DenseNet built previously to recognize breast cancer.

1. Setting up the GPU and importing libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os, PIL, pathlib
from collections import OrderedDict
import torchsummary as summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Output:

device(type='cuda')

2. Importing and preprocessing the data

data_dir = './data/J3-1-data'
data_dir = pathlib.Path(data_dir)

data_path = list(data_dir.glob('*'))
classNames = [path.name for path in data_path]
print(classNames)

Output:

['0', '1']

As shown, this dataset has only two classes: 0 means no breast cancer, 1 means breast cancer.
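
Since class imbalance matters for a medical dataset, it can also be worth counting the images per class before training. A minimal sketch (my addition, not part of the original run), reusing the data_dir defined above:

# Optional check of the class balance: count the files in each class folder.
for class_dir in sorted(data_dir.glob('*')):
    print(f"class {class_dir.name}: {len(list(class_dir.glob('*')))} images")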

Next we define the transforms:

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
total_data

Output:

Dataset ImageFolder
    Number of datapoints: 13403
    Root location: data\J3-1-data
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

There are 13,403 images in total, all resized and normalized by the transform above.
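
The Normalize values used here are the standard ImageNet statistics, a common default. If you would rather normalize with statistics computed from this dataset itself, a rough sketch (an assumption on my part, not what was run here; it makes one slow full pass over all 13,403 images):

# Sketch: estimate per-channel mean/std of this dataset in one pass.
stat_ds = datasets.ImageFolder(data_dir, transform=transforms.Compose([
    transforms.Resize([224, 224]), transforms.ToTensor()]))
stat_dl = DataLoader(stat_ds, batch_size=64)
n, mean, sq = 0, torch.zeros(3), torch.zeros(3)
for x, _ in stat_dl:
    n += x.size(0)
    mean += x.mean(dim=[2, 3]).sum(0)      # sum of per-image channel means
    sq += (x ** 2).mean(dim=[2, 3]).sum(0)  # sum of per-image channel mean squares
mean /= n
std = (sq / n - mean ** 2).sqrt()
print(mean, std)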

Next we split the data into training, test, and validation sets:

train_size = int(0.7 * len(total_data))
remain_size = len(total_data) - train_size
train_dataset, remain_dataset = torch.utils.data.random_split(total_data, [train_size, remain_size])

test_size = int(0.6 * len(remain_dataset))
validate_size = len(remain_dataset) - test_size
test_dataset, validate_dataset = torch.utils.data.random_split(remain_dataset, [test_size, validate_size])  # randomly split the remaining data

train_dataset, test_dataset, validate_dataset

Output:

(<torch.utils.data.dataset.Subset at 0x22815024c20>,
 <torch.utils.data.dataset.Subset at 0x2281501d5e0>,
 <torch.utils.data.dataset.Subset at 0x22815024710>)

What is printed here is just the memory address of each Subset object.
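
One caveat: random_split draws a fresh permutation every run, so the train/test/validation membership changes between sessions. If reproducibility matters, a seeded generator can be passed in (a sketch, not what was done in this run):

# Sketch: make the split reproducible with a fixed seed.
g = torch.Generator().manual_seed(42)
train_dataset, remain_dataset = torch.utils.data.random_split(
    total_data, [train_size, remain_size], generator=g)
test_dataset, validate_dataset = torch.utils.data.random_split(
    remain_dataset, [test_size, validate_size], generator=g)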

Next we wrap the datasets in DataLoaders:

batch_size = 32

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
validate_dl = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False)

for x, y in validate_dl:
    print("shape of x [N, C, H, W]:", x.shape)
    print("shape of y:", y.shape, y.dtype)
    break

Output:

shape of x [N, C, H, W]: torch.Size([32, 3, 224, 224])
shape of y: torch.Size([32]) torch.int64

3. Visualizing the data

# Define the de-normalization function
def unnormalize(img, mean, std):
    mean = np.array(mean)
    std = np.array(std)
    img = img * std + mean     # undo the normalization
    return np.clip(img, 0, 1)  # clamp values to [0, 1]

plt.figure(figsize=(10, 5))
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

for images, labels in validate_dl:  # take one batch from the DataLoader
    for i in range(8):  # show the first 8 images
        ax = plt.subplot(2, 4, i + 1)  # 2 rows x 4 columns of subplots
        img = images[i].permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        img = unnormalize(img, mean, std)  # de-normalize
        plt.imshow(img)  # values should now be in [0, 1]
        plt.title(classNames[labels[i].item()])  # class name as title
        plt.axis("off")
    break  # only show the first batch

Output:
[Figure: a 2×4 grid of sample validation images, each titled with its class label]

II. Building the DenseNet Network

We reuse the DenseNet-121 built last week. One detail worth noting: the last dense block here is configured with 6 layers, block_config=(6, 12, 24, 6), rather than the 16 layers of the standard DenseNet-121, so the final feature width is 704 instead of 1024 (this matches the printed model below):

class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # dense connectivity: concatenate the input with the new feature maps
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1,), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(2, stride=2))


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()
        # stem: conv + pool
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(3, stride=2, padding=1))
        ]))
        # dense blocks with transition layers in between
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_features, int(num_features * compression_rate))
                self.features.add_module('transition%d' % (i + 1), transition)
                num_features = int(num_features * compression_rate)
        # final batch norm + relu
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))
        # classifier
        self.classifier = nn.Linear(num_features, num_classes)
        # parameter initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.avg_pool2d(features, kernel_size=7).view(features.size(0), -1)
        out = self.classifier(out)
        return out


densenet121 = DenseNet(num_init_features=64, growth_rate=32,
                       block_config=(6, 12, 24, 6), num_classes=len(classNames))
model = densenet121.cuda()
model

Output:

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      ... (denselayer2-denselayer6 follow the same pattern, with input channels 96, 128, 160, 192, 224) ...
    )
    (transition1): _Transition(
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (denseblock2): _DenseBlock(... 12 dense layers, input channels 128 up to 480, output 512 ...)
    (transition2): _Transition(... Conv2d(512, 256, kernel_size=(1, 1)), AvgPool2d(2) ...)
    (denseblock3): _DenseBlock(... 24 dense layers, input channels 256 up to 992, output 1024 ...)
    (transition3): _Transition(... Conv2d(1024, 512, kernel_size=(1, 1)), AvgPool2d(2) ...)
    (denseblock4): _DenseBlock(... 6 dense layers, input channels 512 up to 672, output 704 ...)
    (norm5): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu5): ReLU(inplace=True)
  )
  (classifier): Linear(in_features=704, out_features=2, bias=True)
)
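
Before training, a quick sanity check is to run a dummy batch through the network and confirm the output shape matches the two classes (a small sketch of my own, using the device set earlier):

# Sanity check: push a dummy batch through and inspect the output shape.
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224).to(device))
print(out.shape)  # expected: torch.Size([2, 2]) -> (batch, num_classes)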

We summarize the model with torchsummary:

summary.summary(model, (3, 224, 224))

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
               ...                         ...                ...
      BatchNorm2d-305            [-1, 704, 7, 7]           1,408
             ReLU-306            [-1, 704, 7, 7]               0
           Linear-307                    [-1, 2]           1,410
================================================================
Total params: 5,481,026
Trainable params: 5,481,026
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.44
Params size (MB): 20.91
Estimated Total Size (MB): 307.92
----------------------------------------------------------------
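
The parameter total can also be cross-checked directly against the model, without torchsummary:

# Cross-check the parameter count reported by torchsummary.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"total: {total:,}, trainable: {trainable:,}")  # 5,481,026 for this config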

III. Model Training

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss, train_acc = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, test_acc = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss += loss.item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau

opt = torch.optim.Adam(model.parameters(), lr=1e-4)
# When the monitored metric (here the test loss) has not improved for 5 epochs,
# multiply the learning rate by 0.1.
scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=5, verbose=True)
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss

epochs = 32
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
best_acc = 0  # best accuracy so far, used to select the best model

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    scheduler.step(epoch_test_loss)

    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # current learning rate
    lr = opt.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# save the best model to disk
PATH = './best_model.pth'  # file name for the saved parameters
torch.save(best_model.state_dict(), PATH)
print('Done')

Output:

Epoch: 1, Train_acc:84.8%, Train_loss:0.353, Test_acc:88.2%, Test_loss:0.290, Lr:1.00E-04
Epoch: 2, Train_acc:88.1%, Train_loss:0.287, Test_acc:89.8%, Test_loss:0.259, Lr:1.00E-04
Epoch: 3, Train_acc:89.0%, Train_loss:0.269, Test_acc:89.3%, Test_loss:0.278, Lr:1.00E-04
Epoch: 4, Train_acc:90.2%, Train_loss:0.240, Test_acc:90.8%, Test_loss:0.223, Lr:1.00E-04
Epoch: 5, Train_acc:90.5%, Train_loss:0.235, Test_acc:89.1%, Test_loss:0.266, Lr:1.00E-04
Epoch: 6, Train_acc:91.4%, Train_loss:0.218, Test_acc:90.9%, Test_loss:0.226, Lr:1.00E-04
Epoch: 7, Train_acc:91.9%, Train_loss:0.204, Test_acc:91.6%, Test_loss:0.229, Lr:1.00E-04
Epoch: 8, Train_acc:92.5%, Train_loss:0.191, Test_acc:91.2%, Test_loss:0.240, Lr:1.00E-04
Epoch: 9, Train_acc:92.2%, Train_loss:0.189, Test_acc:90.7%, Test_loss:0.227, Lr:1.00E-04
Epoch:10, Train_acc:93.0%, Train_loss:0.176, Test_acc:90.3%, Test_loss:0.244, Lr:1.00E-05
Epoch:11, Train_acc:95.3%, Train_loss:0.126, Test_acc:93.6%, Test_loss:0.178, Lr:1.00E-05
Epoch:12, Train_acc:95.9%, Train_loss:0.113, Test_acc:93.5%, Test_loss:0.170, Lr:1.00E-05
Epoch:13, Train_acc:96.3%, Train_loss:0.100, Test_acc:93.7%, Test_loss:0.179, Lr:1.00E-05
Epoch:14, Train_acc:96.6%, Train_loss:0.093, Test_acc:93.7%, Test_loss:0.176, Lr:1.00E-05
Epoch:15, Train_acc:97.1%, Train_loss:0.085, Test_acc:93.0%, Test_loss:0.185, Lr:1.00E-05
Epoch:16, Train_acc:96.9%, Train_loss:0.082, Test_acc:93.3%, Test_loss:0.182, Lr:1.00E-05
Epoch:17, Train_acc:97.5%, Train_loss:0.069, Test_acc:92.9%, Test_loss:0.184, Lr:1.00E-05
Epoch:18, Train_acc:97.6%, Train_loss:0.068, Test_acc:93.2%, Test_loss:0.195, Lr:1.00E-06
Epoch:19, Train_acc:98.3%, Train_loss:0.054, Test_acc:93.2%, Test_loss:0.187, Lr:1.00E-06
Epoch:20, Train_acc:98.3%, Train_loss:0.058, Test_acc:93.7%, Test_loss:0.186, Lr:1.00E-06
Epoch:21, Train_acc:98.3%, Train_loss:0.053, Test_acc:93.3%, Test_loss:0.185, Lr:1.00E-06
Epoch:22, Train_acc:98.2%, Train_loss:0.056, Test_acc:93.5%, Test_loss:0.187, Lr:1.00E-06
Epoch:23, Train_acc:98.5%, Train_loss:0.051, Test_acc:93.4%, Test_loss:0.191, Lr:1.00E-06
Epoch:24, Train_acc:98.5%, Train_loss:0.051, Test_acc:93.1%, Test_loss:0.184, Lr:1.00E-07
Epoch:25, Train_acc:98.3%, Train_loss:0.052, Test_acc:93.5%, Test_loss:0.184, Lr:1.00E-07
Epoch:26, Train_acc:98.5%, Train_loss:0.051, Test_acc:93.4%, Test_loss:0.186, Lr:1.00E-07
Epoch:27, Train_acc:98.1%, Train_loss:0.057, Test_acc:93.4%, Test_loss:0.187, Lr:1.00E-07
Epoch:28, Train_acc:98.2%, Train_loss:0.052, Test_acc:93.4%, Test_loss:0.191, Lr:1.00E-07
Epoch:29, Train_acc:98.4%, Train_loss:0.052, Test_acc:93.6%, Test_loss:0.188, Lr:1.00E-07
Epoch:30, Train_acc:98.3%, Train_loss:0.054, Test_acc:93.7%, Test_loss:0.184, Lr:1.00E-08
Epoch:31, Train_acc:98.4%, Train_loss:0.050, Test_acc:93.4%, Test_loss:0.184, Lr:1.00E-08
Epoch:32, Train_acc:98.5%, Train_loss:0.052, Test_acc:93.3%, Test_loss:0.184, Lr:1.00E-08
Done
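
Since only the state dict was saved, restoring the weights later (e.g. in a fresh session) requires rebuilding the same architecture first. A sketch, assuming the DenseNet class and classNames defined above are available:

# Sketch: restore the best weights into a freshly constructed model.
restored = DenseNet(num_init_features=64, growth_rate=32,
                    block_config=(6, 12, 24, 6),
                    num_classes=len(classNames)).to(device)
restored.load_state_dict(torch.load('./best_model.pth', map_location=device))
restored.eval()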

IV. Data Visualization

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

Output:
[Figure: training/test accuracy (left) and loss (right) curves over the 32 epochs]

The test accuracy levels off at around 93%.

V. Making Predictions

We use the best model to predict on a batch of validation images:

plt.figure(figsize=(10, 5))

# Iterate over the validation set and take the first batch
for images, labels in validate_dl:
    for i in range(8):  # only show the first 8 images
        ax = plt.subplot(2, 4, i + 1)

        # show the image
        img = images[i].permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
        img = unnormalize(img, mean, std)  # de-normalize
        plt.imshow(img)  # values in [0, 1]

        # add a batch dimension for the model
        img_tensor = images[i].unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")

        # predict the class with the best model
        best_model.eval()  # evaluation mode
        with torch.no_grad():  # no gradients needed
            predictions = best_model(img_tensor)
        predicted_class_index = predictions.argmax(dim=1).item()  # predicted class index
        predicted_class = classNames[predicted_class_index]  # predicted class name

        # true class name
        true_class = classNames[labels[i].item()]

        # title: true (T) vs. predicted (P) class
        plt.title(f"T: {true_class}\nP: {predicted_class}")
        plt.axis("off")

        print(f"Image {i+1}: True Label = {true_class}, Predicted Label = {predicted_class}")
    break  # only process the first batch

Output:

Image 1: True Label = 0, Predicted Label = 1
Image 2: True Label = 0, Predicted Label = 0
Image 3: True Label = 1, Predicted Label = 1
Image 4: True Label = 0, Predicted Label = 0
Image 5: True Label = 0, Predicted Label = 0
Image 6: True Label = 1, Predicted Label = 1
Image 7: True Label = 0, Predicted Label = 0
Image 8: True Label = 0, Predicted Label = 0

[Figure: the 8 validation images shown with true (T) and predicted (P) labels]

Finally, we check the overall accuracy on the validation set:

def validate(dataloader, model):
    model.eval()
    size = len(dataloader.dataset)
    validate_acc = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            validate_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    validate_acc /= size
    return validate_acc

# compute the validation accuracy
validate_acc = validate(validate_dl, best_model)
print(f"Validation Accuracy: {validate_acc:.2%}")

Output:

Validation Accuracy: 93.23%

The validation accuracy reaches 93.23%, a solid overall result.
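
For a cancer-screening task, though, overall accuracy can hide asymmetric errors: a false negative (a malignant case predicted as benign) is usually far more costly than a false positive. A minimal sketch of my own (not part of the original run) that tallies a 2×2 confusion matrix and the sensitivity on the validation loader, in plain PyTorch:

# Sketch: 2x2 confusion matrix on the validation set.
# Rows = true class (0 = not cancer, 1 = cancer), columns = predicted class.
confusion = torch.zeros(2, 2, dtype=torch.long)
best_model.eval()
with torch.no_grad():
    for x, y in validate_dl:
        pred = best_model(x.to(device)).argmax(1).cpu()
        for t, p in zip(y, pred):
            confusion[t, p] += 1
print(confusion)
tp, fn = confusion[1, 1].item(), confusion[1, 0].item()
print(f"sensitivity (recall on class 1): {tp / (tp + fn):.2%}")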
