MNIST_FC

前言

提醒:
文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。
其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展及意见建议,欢迎评论区讨论交流。

文章目录

  • 前言
  • 数据集
  • 全连接神经网络
      • 网络结构
      • 数学表示
      • 总结
      • 数学公式总结
      • 激活函数
      • 训练过程(反向传播)
      • 总结
  • 项目实例


数据集

数据集可参考:MNIST数据集_CNN
本文章采用全连接神经网络进行分析。

全连接神经网络

全连接神经网络(Fully Connected Neural Network,简称 FCNNDNN,即深度神经网络)是一种最基础的神经网络结构。在这种网络中,每一层的每个神经元都与上一层的每个神经元连接,因此称为“全连接”。

网络结构

全连接网络由多个层组成,常见的层包括输入层、隐藏层和输出层。每一层之间的连接是全连接的,即每一层的每个神经元都与上一层的所有神经元相连。

假设我们有一个包含 L L L 层的神经网络,输入层为 x ∈ R n x \in \mathbb{R}^{n} xRn(输入数据是 n n n 维的向量),经过 L − 1 L-1 L1 个隐藏层,最后输出层为 y ∈ R m y \in \mathbb{R}^{m} yRm(输出是 m m m 维的向量)。

数学表示

假设神经网络的第 l l l 层有 N l N_l Nl 个神经元,层与层之间的连接通过权重矩阵和偏置向量表示。

  1. 输入层到隐藏层的计算
    每一层的神经元输出是由前一层神经元加权求和之后,通过一个激活函数进行非线性变换得到的。

    对于第 l l l 层神经网络(假设 l ≥ 1 l \geq 1 l1),其输入是上一层的输出 a [ l − 1 ] \mathbf{a}^{[l-1]} a[l1],输出是该层的激活值 a [ l ] \mathbf{a}^{[l]} a[l]

    • 加权求和:第 l l l 层的加权求和公式为
      z [ l ] = W [ l ] a [ l − 1 ] + b [ l ] , \mathbf{z}^{[l]} = \mathbf{W}^{[l]} \mathbf{a}^{[l-1]} + \mathbf{b}^{[l]}, z[l]=W[l]a[l1]+b[l],
      其中, W [ l ] ∈ R N l × N l − 1 \mathbf{W}^{[l]} \in \mathbb{R}^{N_l \times N_{l-1}} W[l]RNl×Nl1 是该层的权重矩阵, b [ l ] ∈ R N l \mathbf{b}^{[l]} \in \mathbb{R}^{N_l} b[l]RNl 是偏置向量, a [ l − 1 ] ∈ R N l − 1 \mathbf{a}^{[l-1]} \in \mathbb{R}^{N_{l-1}} a[l1]RNl1 是上一层的激活值。

    • 激活函数:加权和 z [ l ] \mathbf{z}^{[l]} z[l] 会通过激活函数(如 ReLU、Sigmoid 或 Tanh)得到当前层的激活值:
      a [ l ] = f ( z [ l ] ) , \mathbf{a}^{[l]} = f(\mathbf{z}^{[l]}), a[l]=f(z[l]),
      其中 f ( ⋅ ) f(\cdot) f() 表示激活函数。

  2. 输出层的计算
    最后,输出层的计算通常也是类似的,输出是最终的预测值 y \mathbf{y} y。假设输出层没有激活函数(或者使用 softmax 激活函数用于分类问题),那么我们有:
    y = W [ L ] a [ L − 1 ] + b [ L ] , \mathbf{y} = \mathbf{W}^{[L]} \mathbf{a}^{[L-1]} + \mathbf{b}^{[L]}, y=W[L]a[L1]+b[L],
    其中 W [ L ] \mathbf{W}^{[L]} W[L] 是输出层的权重矩阵, b [ L ] \mathbf{b}^{[L]} b[L] 是输出层的偏置。

总结

因此,神经网络的前向传播过程可以概括为:

  1. 输入层到隐藏层:每一层的输入是上一层的输出,经过加权求和和激活函数得到输出。
  2. 最终输出层:输出层的计算与隐藏层类似,最终输出模型的预测结果。

数学公式总结

  1. 对于第 l l l 层:
    z [ l ] = W [ l ] a [ l − 1 ] + b [ l ] , \mathbf{z}^{[l]} = \mathbf{W}^{[l]} \mathbf{a}^{[l-1]} + \mathbf{b}^{[l]}, z[l]=W[l]a[l1]+b[l],
    a [ l ] = f ( z [ l ] ) , \mathbf{a}^{[l]} = f(\mathbf{z}^{[l]}), a[l]=f(z[l]),
    其中 f ( ⋅ ) f(\cdot) f() 是激活函数。

  2. 对于输出层:
    y = W [ L ] a [ L − 1 ] + b [ L ] . \mathbf{y} = \mathbf{W}^{[L]} \mathbf{a}^{[L-1]} + \mathbf{b}^{[L]}. y=W[L]a[L1]+b[L].

激活函数

在每一层的计算中,激活函数是引入非线性的关键。常用的激活函数包括:

  • Sigmoid 函数: f ( x ) = 1 1 + e − x f(x) = \frac{1}{1 + e^{-x}} f(x)=1+ex1,常用于二分类问题。
  • Tanh 函数: f ( x ) = e x − e − x e x + e − x f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} f(x)=ex+exexex,输出范围是 ([-1, 1])。
  • ReLU(Rectified Linear Unit): f ( x ) = max ⁡ ( 0 , x ) f(x) = \max(0, x) f(x)=max(0,x),对大多数问题表现良好,尤其是在深度学习中。
  • Softmax:用于多分类问题,输出的是概率分布。

训练过程(反向传播)

神经网络的训练主要依赖于 反向传播算法(Backpropagation)。通过计算损失函数的梯度,使用梯度下降法更新权重和偏置,使得网络的预测误差最小化。

  1. 损失函数
    常用的损失函数包括:

    • 均方误差(MSE):用于回归问题。
    • 交叉熵损失:用于分类问题。
  2. 反向传播
    通过链式法则计算每一层的梯度,并更新每一层的权重和偏置。假设 L \mathcal{L} L 是损失函数,更新规则为:
    W [ l ] ← W [ l ] − η ∂ L ∂ W [ l ] , \mathbf{W}^{[l]} \leftarrow \mathbf{W}^{[l]} - \eta \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{[l]}}, W[l]W[l]ηW[l]L,
    b [ l ] ← b [ l ] − η ∂ L ∂ b [ l ] , \mathbf{b}^{[l]} \leftarrow \mathbf{b}^{[l]} - \eta \frac{\partial \mathcal{L}}{\partial \mathbf{b}^{[l]}}, b[l]b[l]ηb[l]L,
    其中, η \eta η 是学习率, ∂ L ∂ W [ l ] \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{[l]}} W[l]L ∂ L ∂ b [ l ] \frac{\partial \mathcal{L}}{\partial \mathbf{b}^{[l]}} b[l]L 分别是权重和偏置的梯度。

总结

全连接神经网络是一种简单但非常强大的模型,它可以通过多层非线性变换来学习复杂的映射关系。在实践中,通常使用多个隐藏层来构建深度神经网络,并结合适当的优化算法(如梯度下降、Adam等)进行训练。

项目实例

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# 设置设备为 GPU(如果可用),否则使用 CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 设置网络的超参数
input_size = 784  # 输入层的大小(MNIST图片28x28,扁平化后是784维)
hidden_size = 500  # 隐藏层的神经元数量
num_classes = 10  # 输出类别数(MNIST有10个数字分类)
num_epochs = 5  # 训练的轮数
batch_size = 100  # 每个批次的数据量
learning_rate = 0.001  # 学习率# 下载并加载 MNIST 训练数据集
train_dataset = torchvision.datasets.MNIST(root='data', train=True,  # 训练数据集download=True,  # 如果数据集不存在就下载transform=transforms.ToTensor())  # 将图片转换为Tensor格式# 下载并加载 MNIST 测试数据集
test_dataset = torchvision.datasets.MNIST(root='data', train=False,  # 测试数据集transform=transforms.ToTensor(),  # 同样转换为Tensor格式download=True)# 使用DataLoader将数据集加载为可迭代的批次,自动打乱训练数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# 测试数据集的DataLoader,不进行打乱
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 打印训练集和测试集的长度
print(len(train_dataset), len(test_dataset))# 打印训练集和测试集的第一个样本的图像和标签
print(train_dataset[0][0].shape, train_dataset[0][1])
print(test_dataset[0][0].shape, test_dataset[0][1])# 打印测试集的一个批次的图像和标签形状
for X, y in test_loader:print(f"Shape of X [N, C, H, W]: {X.shape}")  # N:批次大小, C:通道数, H:高度, W:宽度print(f"Shape of y: {y.shape} {y.dtype}")  # y: 标签的形状和类型break  # 只查看一个批次
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()# 第一个全连接层,输入784维,输出500维self.fc1 = nn.Linear(input_size, hidden_size)# 激活函数,使用Sigmoidself.sigmoid = nn.Sigmoid()# 第二个全连接层,输入500维,输出10维(对应10个类别)self.fc2 = nn.Linear(hidden_size, num_classes)  # 定义前向传播过程def forward(self, x):out = self.fc1(x)  # 输入经过第一层out = self.sigmoid(out)  # 激活函数out = self.fc2(out)  # 输入经过第二层return out
# 实例化神经网络模型并转移到GPU(如果可用)
model = NeuralNet(input_size, hidden_size, num_classes).to(device)# 使用均方误差(MSELoss)作为损失函数,这对于多类分类问题通常不太合适,但这里示范用作教学
criterion = nn.MSELoss()# 使用Adam优化器,学习率为0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # 获取训练集的总步数
total_step = len(train_loader)
# 训练过程
for epoch in range(num_epochs):  # 遍历每个epochfor i, (images, labels) in enumerate(train_loader):  # 遍历每个batchimages = images.reshape(-1, 28*28).to(device)  # 将每张28x28的图片展开成784维# 将标签转为one-hot编码形式,并转换为浮点类型(MSELoss要求标签是浮点数)labels = torch.nn.functional.one_hot(labels, num_classes=10).float().to(device)# 前向传播:将输入数据传入模型outputs = model(images)# 计算损失loss = criterion(outputs, labels)# 梯度清零,反向传播,更新参数optimizer.zero_grad()loss.backward()optimizer.step()# 每100个步骤打印一次当前损失if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# 在测试集上评估模型的性能
with torch.no_grad():  # 不计算梯度correct = 0total = 0# 遍历测试数据for images, labels in test_loader:images = images.reshape(-1, 28*28).to(device)  # 扁平化图片labels = labels.to(device)  # 获取标签# 前向传播:通过模型获得输出outputs = model(images)# 获取预测类别(根据最大输出值选择类别)predicted = torch.argmax(outputs.data, dim=1)total += labels.size(0)  # 统计总数correct += (predicted == labels).sum().item()  # 统计正确预测的个数# 打印测试集上的准确率print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

代码运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/35557.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

掌握时间,从`datetime`开始

文章目录 掌握时间,从datetime开始第一部分:背景介绍第二部分:datetime库是什么?第三部分:如何安装这个库?第四部分:简单库函数使用方法1. 获取当前日期和时间2. 创建特定的日期3. 计算两个日期…

算法之括号匹配中最长有效字符串

目录 1. 题目2. 解释3. 思路4. 代码5. 总结 1. 题目 任何一个左括号都能找到和其正确配对的右括号任何一个右括号都能找到和其正确配对的左括号 求最长的有效的括号长度 2. 解释 例如,这里的括号 ((((()()()()()()()))()最长有效是:((()()()()()()(…

统信桌面专业版部署postgresql-14.2+postgis-3.2方法介绍

文章来源:统信桌面专业版部署postgresql-14.2postgis-3.2方法介绍 | 统信软件-知识分享平台 应用场景 CPU架构:X86(海光C86-3G 3350) OS版本信息:1070桌面专业版 软件信息:postgresql-14.2postgis-3.2 …

【书生大模型实战营】Python 基础知识-L0G2000

前言:本文是书生大模型实战营系列的第2篇文章,是入门岛的第二个任务,主题为:Python基础知识。 官方教程参考链接:Tutorial/docs/L0/Python at camp4 InternLM/Tutorial 1.任务概览 本关为Python基础关卡&#xff0…

智能安全新时代:大语言模型与智能体在网络安全中的革命性应用

一、引言 随着信息技术的飞速发展,网络安全问题日益严重,成为各行各业面临的重大挑战。传统的安全防护措施已难以应对日益复杂的网络威胁,人工智能(AI)技术的引入为网络安全带来了新的希望。特别是大语言模型&#xff…

数仓技术hive与oracle对比(三)

更新处理 oracle使用dblink透明网关连接其他数据库,mysql、sqlserver、oracle,然后用sql、plsql更新数据;或者使用etl工具实现更新。 hive使用sqoop连接mysql、sqlserver、oracle实现数据更新。 oracle oracle数据加载命令 批量sql脚本上…

在 Vue.js 中使用对象映射和枚举类型

学习啦! 对象映射是一种将一个对象的属性名映射到另一个对象的属性名的方法。 const keyMapping {username: 用户名, gender: { label: 性别, mapping: gender }, // gender 映射为 性别email: 邮箱, // email 映射为 邮箱phone: 电话, // phone 映射为 电话addres…

嵌入式学习(15)-stm32通用GPIO模拟串口发送数据

一、概述 在项目开发中可能会遇到串口不够用的情况这时候可以用通过GPIO来模拟串口的通信方式。 二、协议格式 按照1位起始位8位数据位1位停止位的方式去编写发送端的程序。起始位拉低一个波特率的时间;发送8位数据;拉高一个波特率的时间。 三、代码 …

【C语言期末复习全攻略】:知识点汇总与考试重点剖析、附刷题资料软件

零、引用 期末考试临近,无论你是初学者还是“熬夜选手”,C语言的学习都需要系统梳理和重点突破。本文将全面总结C语言的核心知识点,并针对考试中常见的题型提供复习建议,助你轻松拿下高分。 文末提供了一款免费的C语言刷题软件 …

美颜SDK接入实战:构建智能化直播美颜APP的技术路径详解

如何将美颜SDK顺利接入并构建一个智能化的直播美颜APP呢?本文将从技术路径的角度,带你深入解析这一过程。 一、了解美颜SDK的基本功能 美颜SDK通常包括多个功能模块,针对不同的直播场景,SDK会提供针对性的优化算法,确…

【Spring】Spring事务和事务传播机制

🔥个人主页: 中草药 🔥专栏:【Java】登神长阶 史诗般的Java成神之路 一、Spring事务 我们在MySQL阶段已经学习了MySQL的事务相关知识,详情可见 【MySQL数据库】索引与事务-CSDN博客 1、概念 我们在此做一个简单回顾…

Qt 小项目 学生管理信息系统

主要是对数据库的增删查改的操作 登录/注册界面: 主页面: 添加信息: 删除信息: 删除第一行(支持多行删除) 需求分析: 用QT实现一个学生管理信息系统,数据库为MySQL 要求&#xf…

核心网S6730-H48X6C-V2堆叠

核心网是电信网络的中枢,负责数据传输、服务提供和网络管理,对保障通信质量、支持新技术服务和维护网络安全至关重要。堆叠技术通过将多个网络设备逻辑上整合为一个单元,简化管理,提升网络可用性和性能,同时降低成本,增强网络扩展性。 堆叠在网络建设中至关重要,它通过…

教程: 5分钟部署 APIPark 开源 LLM Gateway 与 API 开放门户

极大简化了大语言模型调用的过程,无需复杂代码即可同时连接主流大语言模型,让企业更加快捷、安全地使用AI。喜欢或感兴趣的小伙伴们赶紧去体验吧! 🔗更详细使用教程可以查看:APIPark 产品使用文档 APIPark 提供出色的…

HTML5教程-表格宽度设置,最大宽度,自动宽度

HTML表格宽度 参考:html table width HTML表格是网页设计中常用的元素之一,可以用来展示数据、创建布局等。表格的宽度是一个重要的参数,可以通过不同的方式来设置表格的宽度,本文将详细介绍HTML表格宽度的不同设置方式和示例代…

RISC-V架构下OP-TEE 安全系统实践

安全之安全(security)博客目录导读 本篇博客,我们聚焦RISC-V 2024中国峰会上的RISC-V和OP-TEE结合的一个安全系统实践,来自芯来科技桂兵老师。 关于RISC-V TEE(可信执行环境)的相关方案,如感兴趣可参考RISC-V TEE(可信执行环境)方案初探 首…

RTK数据的采集方法

采集RTK(实时动态定位)数据通常涉及使用高精度的GNSS(全球导航卫星系统)接收器,并通过基站和流动站的配合来实现。本文给出RTK数据采集的基本步骤 文章目录 准备设备设置基站设置流动站数据采集数据存储与处理应用数据…

【银河麒麟操作系统真实案例分享】内存黑洞导致服务器卡死分析全过程

了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 现象描述 机房显示器连接服务器后黑屏&#xff…

Mongodb副本集环境安全认证

我所配置的mongodb副本集群 step1启动 MongoDB 副本集的每一个节点 mongod --config=/usr/local/mongodb_wjx/wjx01/mongod.conf mongod --config=/usr/local/mongodb_wjx/wjx02/mongod.conf mongod --config=/usr/local/mongodb_wjx/wjx03/mongod.conf step2通过主节点添加管…

完美解决Qt Qml窗口全屏软键盘遮挡不显示

1、前提 说明:我使用的是第三方软键盘 QVirtualKeyboard QVirtualKeyboard: Qt5虚拟键盘支持中英文,仿qt官方的virtualkeyboard模块,但使用QWidget实现。 - Gitee.com 由于参考了几篇文章尝试但没有效果,链接如下: 文章一:可能…