卷积神经网络 (CNN)

代码功能

网络结构:
卷积层:
两个卷积层,每个卷积层后接 ReLU 激活函数。
最大池化层用于降低维度。
全连接层:
使用一个隐藏层(128 个神经元)和一个输出层(10 类分类任务)。
数据集:
使用 PyTorch 内置的 MNIST 数据集,其中包含手写数字的灰度图像。
训练过程:
使用交叉熵损失函数 (CrossEntropyLoss)。
优化器为 Adam,学习率设为 0.001。
每轮训练输出损失。
测试与可视化:
测试模型在测试集上的准确率。
可视化 6 张测试样本的预测结果与真实标签。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 1. 定义卷积神经网络
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 卷积层 1nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 卷积层 2nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层)self.fc_layers = nn.Sequential(nn.Flatten(),nn.Linear(64 * 7 * 7, 128),  # 全连接层 1nn.ReLU(),nn.Linear(128, 10)  # 全连接层 2 (10 类分类))def forward(self, x):x = self.conv_layers(x)x = self.fc_layers(x)return x# 2. 加载 MNIST 数据集
def load_data(batch_size=64):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 标准化])train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader# 3. 训练 CNN
def train_cnn(model, train_loader, criterion, optimizer, epochs=5):model.train()for epoch in range(epochs):total_loss = 0for images, labels in train_loader:images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU(如适用)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")# 4. 测试 CNN
def test_cnn(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU(如适用)outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total:.2f}%")# 5. 可视化测试结果
def visualize_predictions(model, test_loader):model.eval()images, labels = next(iter(test_loader))images, labels = images.cuda(), labels.cuda()outputs = model(images)_, predicted = torch.max(outputs, 1)# 绘制图像与预测结果images, labels, predicted = images.cpu(), labels.cpu(), predicted.cpu()plt.figure(figsize=(12, 8))for i in range(6):plt.subplot(2, 3, i + 1)plt.imshow(images[i].squeeze(), cmap='gray')plt.title(f"True: {labels[i]}, Pred: {predicted[i]}")plt.axis('off')plt.show()# 主程序
if __name__ == "__main__":# 加载数据train_loader, test_loader = load_data()# 初始化网络、损失函数和优化器model = CNN().cuda()  # 将模型移动到 GPU(如适用)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练和测试模型train_cnn(model, train_loader, criterion, optimizer, epochs=5)test_cnn(model, test_loader)# 可视化部分测试结果visualize_predictions(model, test_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/16535.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

等保二级需要哪些安全设备?

在信息化高速发展的今天,服务器的安全性成为了企业乃至国家信息安全的重要基石。等保二级,作为信息安全等级保护制度中的一个关键环节,对服务器的安全防护提出了明确要求。本文将详细阐述服务器等保二级所需的各种安全设备,旨在为…

C++【深入项目-检测键盘】

神马是检测键盘,就是让编辑器可以检测键盘按下了什么按键,我们先科普复习检测键盘 。 检测键盘需要用到一些函数,请见下: ! KEY_DOWN( 80 ) 这个代码是检测按下键盘上P按键。那80是什么?原来是对应按键的&#xff0…

问题An object named ‘ResNetArcFace‘ was already registered in ‘arch‘ registry!

在安装 GFPGAN 的时候,一切都顺利,但是执行的时候出现了错误,哦还有一个问题, 问题一 就是如果basicsr安装不成功可以执行如下命令 pip install -i https://mirrors.aliyun.com/pypi/simple tb-nightly pip install -i https:/…

Leecode刷题C语言之最少翻转次数使二进制矩阵回文①

执行结果:通过 执行用时和内存消耗如下: 题目:最少翻转次数使二进制矩阵回文① 给你一个 m x n 的二进制矩阵 grid 。如果矩阵中一行或者一列从前往后与从后往前读是一样的,那么我们称这一行或者这一列是 回文 的。你可以将 grid 中任意格子…

K8S containerd拉取harbor镜像

前言 接前面的环境 K8S 1.24以后开始启用docker作为CRI,这里用containerd拉取 正文 vim /etc/containerd/config.toml #修改内容如下 #sandbox_image "registry.aliyuncs.com/google_containers/pause:3.10" systemd_cgroup true [plugins."io.…

三、计算机视觉_01图像的基本操作

0 前言 图像的读取和处理是计算机视觉领域中的一个基本任务,在Python中,有几个流行的库可以用来读取和处理图像数据 0.1 Matplotlib介绍 Matplotlib是Python中一个非常流行的绘图库,它通常用于数据可视化,虽然它不是专门的图像…

Liunx-Ubuntu22.04.1系统下配置Anaconda+pycharm+pytorch-gpu环境配置

这里写自定义目录标题 Liunx-Ubuntu22.04.1系统下配置Anacondapycharmpytorch-gpu环境配置一、Anaconda3配置1.Anaconda安装2.Anaconda更新3.Anaconda删除 二、pycharm配置1.pycharm安装 三、pytorch配置 Liunx-Ubuntu22.04.1系统下配置Anacondapycharmpytorch-gpu环境配置 一…

[Mysql] Mysql的多表查询----多表关系(下)

4、操作 方式二&#xff1a;创建表之后设置外键约束 外键约束也可以在修改表时添加&#xff0c;但是添加外键约束的前提是&#xff1a;从表中外键列中的数据必须与主表中主键列中的数据一致或者是没有数据。 语法&#xff1a; alter table <从表名> add constr…

WukongCRM:github高分开源项目,基于微服务架构 +vue ElementUI的前后端分离CRM系统

嗨&#xff0c;大家好&#xff0c;我是小华同学&#xff0c;关注我们获得“最新、最全、最优质”开源项目和高效工作学习方法 WukongCRM 是一款基于 Spring Cloud Alibaba 微服务架构和 Vue ElementUI 前后端分离的 CRM 系统。它在中国开源管理软件行业具有较高的知名度&#x…

【C#】C#编程入门指南:构建你的.NET开发基础

文章目录 前言&#xff1a;1. C# 开发环境 VS的基本熟悉2. 解决方案与项目的关系3. 编辑、编译、链接、运行4. 托管代码和CLR4.1 CLR&#xff1a;4.2 C# 代码第编译过程&#xff08;两次编译的&#xff09; 5. 命名空间6. 类的组成与分析7. C# 的数据类型7.1 值类型7.2 引用类型…

文心一言 VS 讯飞星火 VS chatgpt (392)-- 算法导论25.1 6题

六、假定我们还希望在本节所讨论的算法里计算出最短路径上的结点。说明如何在 O ( n 3 ) O(n^3) O(n3) 时间内从已经计算出的最短路径权重矩阵 L L L 计算出前驱矩阵Ⅱ。如果要写代码&#xff0c;请用go语言。 文心一言&#xff1a; 要计算前驱矩阵 $ \pi $&#xff0c;我们…

《网络硬件设备完全技术宝典》

《网络硬件设备完全技术宝典》 网卡 集线器 交换机 路由器 双绞线 光缆 无线接入点AP 交换机技术与选择策略 冗余链路技术 由于物理链路和网络模块的损坏都将导致网络链路的失败&#xff0c;因此两个设备之间&#xff0c;特别是核心交换机与汇聚交换机之间的单链路…

CC3学习记录

&#x1f338; CC3 之前学习到的cc1和cc6都是通过Runtime进行命令执行的&#xff0c;如果Runtime被加入黑名单的话&#xff0c;整个链子也就失效了。而cc3则是通过动态类加载机制进行任意代码执行的。 &#x1f338; 版本限制 JDK版本&#xff1a;8u65 Commons-Collections…

机器学习 ---线性回归

目录 摘要&#xff1a; 一、简单线性回归与多元线性回归 1、简单线性回归 2、多元线性回归 3、残差 二、线性回归的正规方程解 1、线性回归训练流程 2、线性回归的正规方程解 &#xff08;1&#xff09;适用场景 &#xff08;2&#xff09;正规方程解的公式 三、衡量…

麒麟服务器日志采集(服务器端)

服务端配置接收模块和监听端口 vim /etc/rsyslog.conf Copy 在rsyslog.conf内输入以下内容 #### MODULES #### module(load"imudp") # needs to be done just once input(type"imudp" port"514") module(load"imtcp") # needs to …

物联网低功耗广域网LoRa开发(三):Lora人机界面

一、TFT液晶屏驱动开发 &#xff08;一&#xff09;驱动源码移植 &#xff08;二&#xff09;硬件接口初始化 根据硬件设计&#xff0c;LoRa与LCD共用SPI总线&#xff0c;且LCD_MISO用于命令/数据模式切换控制 需要修改gpio初始化源码&#xff0c;让片选接口拉高(三)TFT液晶屏…

Android setTheme设置透明主题无效

【问题现象】 1、首先&#xff0c;你在AndroidManifest.xml中声明一个activity&#xff0c;不给application或者activity设置android:theme, 例如这样&#xff1a; <applicationandroid:allowBackup"true"android:icon"mipmap/ic_launcher"android:lab…

JavaScript--定时器

一.定义 关于JavaScript中的计时事件&#xff1f; JavaScript 一个设定的时间间隔之后来执行代码&#xff0c;我们称之为计时事件&#xff08;菜鸟说…&#xff09; 二.方法 2.1计时器 setInterval() &#xff1a; 是什么&#xff1a;这个方法设置一个定时器&#xff0c;…

数据分析-48-时间序列变点检测之在线实时数据的CPD

文章目录 1 时间序列结构1.1 变化点的定义1.2 结构变化的类型1.2.1 水平变化1.2.2 方差变化1.3 变点检测1.3.1 离线数据检测方法1.3.2 实时数据检测方法2 模拟数据2.1 模拟恒定方差数据2.2 模拟变化方差数据3 实时数据CPD3.1 SDAR学习算法3.2 Changefinder模块3.3 恒定方差CPD3…

厦门凯酷全科技有限公司正规吗?

在这个短视频风起云涌的时代&#xff0c;抖音作为电商领域的黑马&#xff0c;正以惊人的速度改变着消费者的购物习惯与品牌的市场策略。在这场变革中&#xff0c;厦门凯酷全科技有限公司凭借其专业的抖音电商服务&#xff0c;在众多服务商中脱颖而出&#xff0c;成为众多品牌信…