# Network architecture (网络结构):
import torch
from torch import nn
from d2l import torch as d2l
class Reshape(torch.nn.Module):
    """Reshape the input batch to ``(N, 1, 224, 224)``.

    The dataset is Fashion-MNIST resized to 224x224. Its images are
    single-channel (grayscale), so the first convolution below takes one
    input channel.
    """

    def forward(self, x):
        return x.view(-1, 1, 224, 224)


# Fashion-MNIST has 10 classes (1000 is the ImageNet class count).
NUM_CLASSES = 10

# AlexNet adapted to 1-channel 224x224 inputs. Spatial sizes are noted after
# each downsampling layer.
net = torch.nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 96, kernel_size=11, stride=4), nn.ReLU(),  # 224 -> 54
    nn.MaxPool2d(3, stride=2),  # 54 -> 26
    nn.Conv2d(96, 256, kernel_size=5, padding=2, stride=1), nn.ReLU(),
    nn.MaxPool2d(3, stride=2),  # 26 -> 12
    nn.Conv2d(256, 384, kernel_size=3, padding=1, stride=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1, stride=1), nn.ReLU(),
    nn.MaxPool2d(3, stride=2),  # 12 -> 5
    nn.Flatten(),  # 256 * 5 * 5 = 6400 features
    nn.Linear(6400, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    # Output raw logits: no Softmax here, because nn.CrossEntropyLoss (used by
    # d2l.train_ch6) applies log-softmax internally.
    nn.Linear(4096, NUM_CLASSES),
)
# Print the output shape of each layer (打印每层图片尺寸):
# Push one dummy 1-channel 224x224 image through the network layer by layer
# and print each layer's output shape.
X = torch.randn(1, 1, 224, 224)
# Shape inspection only: no_grad skips building an autograd graph.
with torch.no_grad():
    for layer in net:
        X = layer(X)
        print(layer.__class__.__name__, 'Output shape:\t', X.shape)
# Training settings (设置训练参数):
batch_size = 128
# Load Fashion-MNIST, upsampling each 28x28 image to the 224x224 input
# size that the network's Reshape layer expects.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size, resize=224)
# Train the model (进行训练):
# Learning rate and number of training epochs.
lr, num_epochs = 0.01, 10
# train_ch6 comes from the d2l package; the bare name is not defined in the
# visible source, so qualify it. Runs on a GPU if available, else the CPU.
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())