深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据

1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。

import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import pandas as pd

2、设置超参数,包括训练批次大小、测试批次大小、学习率和训练周期数。

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 10

3、创建数据转换管道,将图像数据转换为张量并进行标准化。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])

4、下载和预处理MNIST数据集,分为训练集和测试集。

# 下载和预处理数据集
train_dataset = mnist.MNIST('data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('data', train=False, transform=transform)

5、创建用于训练和测试的数据加载器,以便有效地加载数据。

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

 6、定义了一个简单的CNN模型,包括两个卷积层和两个全连接层。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5)self.conv2 = nn.Conv2d(32, 64, kernel_size=5)self.fc1 = nn.Linear(1024, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

7、初始化模型、优化器和损失函数。

# 初始化模型、优化器和损失函数
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

8、准备用于记录训练和测试过程中损失和准确率的列表。

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

9、进入训练循环,遍历每个训练周期。在每个训练周期内,进入训练模式,遍历训练数据批次,计算损失、反向传播并更新模型参数,同时记录训练损失和准确率。

for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)  # 记录训练准确率# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.numpy())all_preds.extend(pred.numpy())

10、在每个训练周期结束后,进入测试模式,遍历测试数据批次,计算测试损失和准确率,同时记录它们。打印每个周期的训练和测试损失以及准确率。

# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

11、losses、acces、eval_losses、eval_acces保存到TXT文件

# 保存训练结果
data = np.column_stack((train_losses,test_losses,train_accuracies, test_accuracies))
np.savetxt("results.txt", data)

12、绘制Loss、ACC图像

# 绘制Loss曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()

 

 13、绘制混淆矩阵图像

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/148614.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp 实现下拉筛选框 二次开发定制

前言 最近又收到了一个需求,需要在uniapp 小程序上做一个下拉筛选框,然后找了一下插件市场,确实有找到,但不过他不支持搜索,于是乎,我就自动动手,进行了二开定制,站在巨人的肩膀上&…

归并排序及其非递归实现

个人主页:Lei宝啊 愿所有美好如期而遇 目录 归并排序递归实现 归并排序非递归实现 归并排序递归实现 图示: 代码: 先分再归并,像是后序一般。 //归并排序 void MergeSort(int* arr, int left, int right) {int* temp (int…

tcp滑动窗口原理

18.1 滑动窗口 我们再来看这个比喻: 网络仅仅是保证了整个网络的连通性,我们我们基于整个网络去传输,那么是不是我想发送多少数据就发送多少数据呢?如果是这样的话,是不是就会像我们的从一个池塘抽水去灌到另外一个…

速看:免费领取4台阿里云服务器_申请入口及领取流程

注册阿里云账号,免费领云服务器,最高领取4台云服务器,每月750小时,3个月免费试用时长,可快速搭建网站/小程序,部署开发环境,开发多种企业应用。阿里云服务器网分享阿里云服务器免费领取入口、免…

openGauss学习笔记-89 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用查询原生编译

文章目录 openGauss学习笔记-89 openGauss 数据库管理-内存优化表MOT管理-内存表特性-使用MOT-MOT使用查询原生编译89.1 查询编译:PREPARE语句89.2 运行命令89.3 轻量执行支持的查询89.4 轻量执行不支持的查询89.5 JIT存储过程89.6 MOT JIT诊断89.6.1 mot_jit_detai…

Pygame中监控鼠标动作的方法

在Pygame中监控键盘按键的方法_pygame获取键盘输入-CSDN博客中提到,通过在while True循环中获取队列中事件的方法监控键盘动作。监控鼠标动作的方法与监控键盘动作的方法相同。 相关连接1 队列与事件的相关知识,请参考 Pygame中监控键盘按键的方法_pyg…

网络爬虫--伪装浏览器

从用户请求的Headers反反爬 在访问某些网站的时候,网站通常会用判断访问是否带有头文件来鉴别该访问是否为爬虫,用来作为反爬取的一种策略。很多网站都会对Headers的User-Agent进行检测,还有一部分网站会对Referer进行检测(一些资…

电脑通过串口助手和51单片机串口通讯

今天有时间把电脑和51单片机之间的串口通讯搞定了,电脑发送的串口数据,单片机能够正常接收并显示到oled屏幕上,特此记录一下,防止后面自己忘记了怎么搞得了。 先来两个图片看看结果吧! 下面是串口3.c的文件全部内容&a…

Eureka

大家好我是苏麟今天带来Eureka的使用 . 提供者和消费者 在服务调用关系中,会有两个不同的角色: 服务提供者:一次业务中,被其它微服务调用的服务。(提供接口给其它微服务) 服务消费者:一次业务…

微信小程序点单左右联动的效果实现

微信小程序点单左右联动的效果实现 原理解析:   点击左边标签会跳到右边相应位置:点击改变rightCur值,转跳相应位置滑动右边,左边标签会跳到相应的位置:监听并且设置每个右边元素的top和bottom,再判断当…

计算机网络基础(一):网络系统概述、OSI七层模型、TCP/IP协议及数据传输

通信,在古代是通过书信与他人互通信息的意思。 今天,“通信”这个词的外沿已经得到了极大扩展,它目前的大意是指双方或多方借助某种媒介实现信息互通的行为。 如果按照当代汉语的方式理解“通信”,那么古代的互遣使节、飞鸽传书…

竞赛选题 机器视觉的试卷批改系统 - opencv python 视觉识别

文章目录 0 简介1 项目背景2 项目目的3 系统设计3.1 目标对象3.2 系统架构3.3 软件设计方案 4 图像预处理4.1 灰度二值化4.2 形态学处理4.3 算式提取4.4 倾斜校正4.5 字符分割 5 字符识别5.1 支持向量机原理5.2 基于SVM的字符识别5.3 SVM算法实现 6 算法测试7 系统实现8 最后 0…

Windows安装Docker并创建Ubuntu环境及运行神经网络模型

目录 前言在Windows上安装Docker在Docker上创建Ubuntu镜像并运行容器创建Ubuntu镜像配置容器,使其可以在宿主机上显示GUI 创建容器并运行神经网络模型创建容器随便找一个神经网络模型试试 总结 前言 学生党一般用个人电脑玩神经网络,估计很少有自己的服…

遗留系统陷入困境

当我们谈论遗留系统时,我们经常会想到数据中心某处的过时服务器和交换机。我们带着一种病态的迷恋阅读了有关系统性技术问题的文章,这些问题在假期周末困扰着其他旅行者,并为他们缺乏远见而摇头。 然后,我们坐在屏幕前&#xff0…

c语言实现玫瑰花

浅浅跟波风 1.效果图 2.代码实现 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <math.h>const int max_iterations 128; const float stop_threshold 0.01f; const float grad_step 0.01f; const float clip_far 10.0f;const float PI 3.1…

【网络安全】2023年堡垒机品牌大全

随着大家网络安全意识的增加&#xff0c;随着国家等保政策的严格执行&#xff0c;越来越多的企业开始采购堡垒机。这里就给大家总结了部分堡垒机品牌&#xff0c;让大家参考参考。 2023年堡垒机品牌大全 1、行云堡垒 2、JumpServer 3、安恒 4、骞云 5、齐治 6、阿里云 …

【项目开发 | C语言项目 | C语言病人管理系统】

该项目旨在为医院或其他医疗机构提供一个简易的病人信息管理工具。用户可以通过命令行界面进行病人信息的增、删、查和改操作&#xff0c;并将数据持久化存储在txt文件中。 一&#xff0c;开发环境需求 操作系统 &#xff1a;Windows, Linux 开发环境工具 &#xff1a;Qt, VSC…

大模型 Decoder 的生成策略

本文将介绍以下内容&#xff1a; IntroductionGreedy Searchbeam searchSamplingTop-K SamplingTop-p (nucleus) sampling总结 一、Introduction 1、简介 近年来&#xff0c;由于在数百万个网页数据上训练的大型基于 Transformer 的语言模型的兴起&#xff0c;开放式语言生…

安装NodeJS并使用yarn下载前端依赖

文章目录 1、安装NodeJS1.1 下载NodeJS安装包1.2 解压并配置NodeJS1.3 验证是否安装成功2、使用yarn下载前端依赖2.1 安装yarn2.2 使用yarn下载前端依赖参考目标:在Windows下安装新版NodeJS,并使用yarn下载前端依赖,实现运行前端项目。 1、安装NodeJS 1.1 下载NodeJS安装包…

Git小书系列笔记

Git准备 首先根据自己的系统安装git&#xff0c;安装成功后可以通过如下指令查看git版本。 使用Git之前&#xff0c;需要配置用户名称和电子邮件。 1.设置全局的用户名和电子邮件 git config --global user.name "Your Name" git config --global user.email &quo…