当前位置: 首页 > news >正文

【PyTorch】colab上跑VGG(深度学习)数据集是 CIFAR10

跑得结果是测试准确率10%,欠拟合。

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transformstransform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])
device =  torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")train_data = datasets.CIFAR10(root='cifar', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='cifar', train=False, download=True, transform=transform)train_data_size = len(train_data)
test_data_size = len (test_data)print("Training data size: {}".format(train_data_size))
print("Testing data size: {}".format(test_data_size))train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)class VGG(nn.Module):def __init__(self):super(VGG, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1), #224 * 224*64nn.ReLU(),nn.Conv2d(64, 64, 3, 1, 1), #224 * 224*64nn.ReLU(),nn.MaxPool2d(2,2), #112 * 112*64nn.Conv2d(64,128, 3, 1, 1),#112 * 112*128nn.ReLU(),nn.Conv2d(128,128, 3, 1, 1), #112 * 112*128nn.ReLU(),nn.MaxPool2d(2,2), #56 * 56*128nn.Conv2d(128,256, 3, 1, 1), #56 * 56*256nn.ReLU(),nn.Conv2d(256,256, 3, 1, 1), #56 * 56*256nn.ReLU(),nn.MaxPool2d(2,2), #28 * 28*256nn.Conv2d(256,512, 3, 1, 1), #28 * 28*512nn.ReLU(),nn.Conv2d(512,512, 3, 1, 1), #28 * 28*512nn.ReLU(),nn.Conv2d(512,512, 3, 1, 1), #28 * 28*512nn.ReLU(),nn.MaxPool2d(2,2), #14 * 14*512nn.Conv2d(512,512, 3, 1, 1), #14 * 14*512nn.ReLU(),nn.Conv2d(512,512, 3, 1, 1), #14 * 14*512nn.ReLU(),nn.Conv2d(512,512, 3, 1, 1), #14 * 14*512nn.ReLU(),nn.MaxPool2d(2,2), #7 * 7*512nn.Flatten(), #7*7*512 -> 25088nn.Linear(25088, 4096), #25088 -> 4096nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), #4096 -> 4096nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 1000), #4096 -> 1000)def forward(self, x):x = self.model(x)return xvgg = VGG()
vgg = vgg.to(device)loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(vgg.parameters(), lr = 0.01)total_train_step = 0
total_test_step = 0epoch = 10writer = SummaryWriter("../logs")
for i in range(epoch):print("---------------------------第{}轮训练开始-------------------------------------".format(i+1))vgg.train()for data in train_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = vgg (imgs)loss = loss_fn(outputs, targets)optim.zero_grad()loss.backward()optim.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数 {},损失值:{}".format(total_train_step,loss))writer.add_scalar("train_loss", loss.item(),total_train_step)#Testingtotal_test_loss = 0total_accuracy = 0vgg.eval()with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = vgg(imgs)loss = loss_fn(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)total_test_step += 1torch.save(vgg, "Elite_{}.pth".format(i))
writer.close()
http://www.xdnf.cn/news/28693.html

相关文章:

  • B端APP设计:打破传统限制,为企业开启便捷新通道
  • 软件架构分层策略对比及Go项目实践
  • 深度解析 SOA:架构原理、核心优势与实践挑战
  • 2025年渗透测试面试题总结-拷打题库06(题目+回答)
  • LeetCode每日一题4.19
  • 【Bluedroid】蓝牙存储模块配置管理:启动、读写、加密与保存流程解析
  • sqlilabs-Less之HTTP头部参数的注入——基础篇
  • [HCIP] OSPF 综合实验
  • Vue3+TS中svg图标的使用
  • 数据分析与挖掘
  • RAGFlow在Docker中运行Ollama直接运行于主机的基础URL的地址
  • opencv 给图片和视频添加水印
  • leetcode57.插入区间
  • Windows系统C盘深度清理指南
  • 车载诊断新架构--- SOVD初入门(上)
  • 23种设计模式-创建型模式之原型模式(Java版本)
  • 医疗器械电磁兼容相关标准
  • 豆瓣图书数据采集与可视化分析(一)- 豆瓣图书数据爬取
  • 性能比拼: Deno vs. Node.js vs. Bun (2025版)
  • C++之虚函数 Virtual Function
  • Redis 的持久化机制(RDB, AOF)对微服务的数据一致性和恢复性有何影响?如何选择?
  • 零基础上手Python数据分析 (18):Matplotlib 基础绘图 - 让数据“开口说话”
  • FPGA——基于DE2_115实现DDS信号发生器
  • FPGA IO引脚 K7-认知4
  • 【java实现+4种变体完整例子】排序算法中【插入排序】的详细解析,包含基础实现、常见变体的完整代码示例,以及各变体的对比表格
  • windows下用xmake交叉编译鸿蒙.so库
  • 交换机与路由器的主要区别:深入分析其工作原理与应用场景
  • hackmyvm-airbind
  • 【人工智能学习-01-01】20250419《数字图像处理》复习材料的word合并PDF,添加页码
  • AI 趋势下 Python 的崛起:深度剖析