MNIST手写数字识别

MNIST是一个手写体数字的图片数据集,该数据集由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,其包含 60,000 张训练图像和 10,000 张测试图像,每张图片的尺寸为 28 x 28
在这里插入图片描述

线性回归

我们尝试通过 线性回归模型 识别手写数字,输入的图片是 28 x 28像素,我们可以将其看为 784 个变量,即:
y = a 1 x 1 + a 2 x 2 + . . . + a n x n y = a_1x_1+a_2x_2+...+a_nx_n y=a1x1+a2x2+...+anxn

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf# 计算梯度,并更新 [a1,a2,...,an],b 值
def gradient(a, b, x_array, y_array, learning_rate):a_gradient = tf.zeros([784, 1])b_gradient = 0length = len(x_array)# 计算梯度for i in range(0, length):x = x_array[i]y = y_array[i]base_gradient = (2 / length) * (np.dot(x, a) + b - y)[0]# print("base_gradient", base_gradient)a_gradient += base_gradient * tf.reshape(x, [784, 1])b_gradient += base_gradient# 更新 a、b 值new_a = a - learning_rate * a_gradientnew_b = b - learning_rate * b_gradientreturn [new_a, new_b]# 计算损失
def computer_loss(a, b, x_array, y_array):length = len(x_array)loss = 0# 计算梯度for i in range(0, length):x = x_array[i]y = y_array[i]loss += (np.dot(x, a) + b - y) ** 2loss /= lengthreturn loss# 计算准确率
def computer_accuracy(a, b, x_array, y_array):accuracy = 0length = len(x_array)for i in range(0, length):x = x_array[i]y = np.dot(x, a) + by = round(y[0])if y == y_array[i]:accuracy += 1return accuracy / lengthmnist = tf.keras.datasets.mnist
(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_size = len(train_data)
test_size = len(test_data)# 手写数字图片是 28x28 的 Tensor,需要将其转换为 1x784
train_data_reshape = tf.reshape(train_data, [train_size, 784])
# 将 int 转换为 float,否则会有计算问题
train_data_reshape = tf.cast(train_data_reshape, dtype=tf.dtypes.float32)
print("train_data_reshape shape", np.shape(train_data_reshape))
# 对数据进行归一化处理
train_data_reshape = train_data_reshape / tf.constant(255.0, shape=[train_size, 784])
print("train_data_reshape", train_data_reshape)test_data_reshape = tf.reshape(test_data, [test_size, 784])
test_data_reshape = tf.cast(test_data_reshape, dtype=tf.dtypes.float32)
train_data_reshape = test_data_reshape / tf.constant(255.0, shape=[test_size, 784])# 假设 y = a1x1 + a2x2 +...+ anxn +b 且 x shape [1,784],则 a shape 为 [784,1]
a = tf.random.normal([784, 1])
b = 0
loss_list = list()
accuracy_list = list()
for i in range(0, 1000):[a, b] = gradient(a, b, train_data_reshape, train_label, 0.01)if i % 10 == 0:loss = computer_loss(a, b, train_data_reshape, train_label)accuracy = computer_accuracy(a, b, test_data_reshape, test_label)print("loss = {} accuracy = {}".format(loss, accuracy))loss_list.append(loss)accuracy_list.append(accuracy)print("a = {} b = {}".format(a, b))
l1 = plt.plot(loss_list, label="loss")
l2 = plt.plot(accuracy_list, label="accuracy")
plt.legend()
plt.show()

在这里插入图片描述
可以看出损失收敛在10左右,准确率只有15%左右,这是因为该模型存在两个问题:

  • 如果预测的数据是 2.5,那实际值应该是2还是3呢?所以应该通过概率来解决该问题,它需要输出多个结果,例如:1的概率为0.999,2的概率为0.0001,3的概率为0.0001等,最终所有结果的概率综合为1。我们称这样的问题为分类问题
  • 图片像素与数字并非线性关系,而是复杂的非线性关系

非线性分类

多输出问题

对于多个结果我们可以考虑使用矩阵的形式,例如 1x4 阶矩阵,需要输出 2 个结果,则可以进行如下运算:
[ a b c d ] ∗ [ 1 5 2 6 3 7 4 8 ] = [ 1 a + 2 b + 3 c + 4 d 5 a + 6 b + 7 c + 8 d ] {\begin{bmatrix} a&b&c&d\\ \end{bmatrix}} * {\begin{bmatrix} 1&5\\ 2&6\\ 3&7\\ 4&8\\ \end{bmatrix}} = {\begin{bmatrix} 1a+2b+3c+4d&5a+6b+7c+8d\\ \end{bmatrix}} [abcd] 12345678 =[1a+2b+3c+4d5a+6b+7c+8d]
手写数字需要10个结果,即 [10] 矩阵,每列的值代表数字 n 的概率,例如表示1的概率为0.999:
[ 0.999 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 ] {\begin{bmatrix} 0.999 & 0.0001 & 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001 \end{bmatrix}} [0.9990.00010.00010.00010.00010.00010.00010.00010.00010.0001]
因为 x x x [ n , 784 ] [n,784] [n,784] 矩阵,所以应该给 x x x 点乘一个 [ 784 , 10 ] [784,10] [784,10] 矩阵,由此多个输出问题得以解决。

非线性问题

我们需要针对线性模型中增加非线性因子,使其变为非线性,这里采用ReLU函数:
在这里插入图片描述
y = r e l u ( a x + b ) y = relu(a x + b) y=relu(ax+b),其中 a = [ 784 , 10 ] a = [784,10] a=[784,10] y = [ 10 ] y = [10] y=[10]

# 线性回归模型识别手写数字import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tfnp.set_printoptions(edgeitems=10, linewidth=200)mnist = tf.keras.datasets.mnist
(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_size = len(train_data)
test_size = len(test_data)# 将 numpy.ndarray 类型数据转换为 Tensor
train_data = tf.convert_to_tensor(train_data, dtype=tf.dtypes.float32)
# 手写数字图片是 28x28 的 Tensor,需要将其转换为 1x784
train_data = tf.reshape(train_data, [train_size, 784])
# 对数据进行归一化处理
train_data = train_data / 255.
print("train_data", train_data)# 对 label 进行 one_hot 编码
train_label = tf.convert_to_tensor(train_label, dtype=tf.dtypes.int8)
train_label = tf.one_hot(train_label, 10)# 沿着第一个维度切片,将 train_data、train_label 转换为 tf.data.Dataset 对象,并按60个合为一个数据集
train_batch = tf.data.Dataset.from_tensor_slices((train_data, train_label)).batch(60)
print("train_batch", train_batch)# 准备测试集数据
test_data = tf.convert_to_tensor(test_data, dtype=tf.dtypes.float32)
test_data = tf.reshape(test_data, [test_size, 784])
test_data = test_data / 255
test_label = tf.one_hot(test_label, 10)model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='relu')
])
optimizer = tf.optimizers.SGD(learning_rate=0.01)def computer_acc():# 预测测试集结果test_out = model.predict(test_data)# 将概率最大置1,其他置0max_val = tf.reduce_max(test_out, axis=1)max_val = tf.reshape(max_val, [-1, 1])test_out = tf.where(tf.equal(test_out, max_val), tf.ones_like(test_out), tf.zeros_like(test_out))# 降维,判断整行数据是否相等acc = tf.reduce_all(tf.equal(test_out, test_label), axis=1)return tf.reduce_mean(tf.cast(acc, tf.float32))loss_list = list()
acc_list = list()
for i in range(0, 1000):for (x, y) in train_batch:# -1 表示自动推断x = tf.reshape(x, (-1, 784))with tf.GradientTape() as tape:out = model(x)loss = tf.reduce_sum(tf.square(out - y) / x.shape[0])# 计算梯度gradient = tape.gradient(loss, model.trainable_variables)# 反向传递optimizer.apply_gradients(zip(gradient, model.trainable_variables))if i % 10 == 0:loss_list.append(loss)acc = computer_acc()acc_list.append(acc)print("i = {} loss = {} acc = {}".format(i, loss, acc))l1 = plt.plot(loss_list, label="loss")
l2 = plt.plot(acc_list, label="acc")
plt.legend()
plt.show()

在这里插入图片描述
最终准确率收敛在了 84% 左右,原因是增加一个非线性因素可能不够,所以我们需要增加多个,使其可以拟合更复杂的非线性函数:
o u t 1 = r e l u ( a 1 x + b ) out_1 = relu(a_1 x + b) out1=relu(a1x+b),其中 a 1 = [ 784 , 512 ] a_1 = [784,512] a1=[784,512] o u t 1 = [ 512 ] out_1 = [512] out1=[512]
o u t 2 = r e l u ( a 2 o u t 1 + b ) out_2 = relu(a_2 out_1 + b) out2=relu(a2out1+b),其中 a 1 = [ 512 , 256 ] a_1 = [512,256] a1=[512,256] o u t 1 = [ 256 ] out_1 = [256] out1=[256]
o u t 3 = r e l u ( a 3 o u t 2 + b ) out_3 = relu(a_3 out_2 + b) out3=relu(a3out2+b),其中 a 1 = [ 256 , 10 ] a_1 = [256,10] a1=[256,10] o u t 3 = [ 10 ] out_3 = [10] out3=[10]

model = tf.keras.Sequential([tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(10)
])

增加两层网络后,最终准确率收敛在了 98% 左右
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146680.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Explain执行计划字段解释说明---ID字段说明

ID字段说明 1、select查询的序列号,包含一组数字,表示查询中执行select子句或操作表的顺序 2、ID的三种情况 (1)id相同,执行顺序由上至下。 (2)id不同,如果是子查询,id的序号会…

NEON优化:性能优化经验总结

NEON优化:性能优化经验总结 1. 什么是 NEONArm Adv SIMD 历史 2. 寄存器3. NEON 命名方式4. 优化技巧5. 优化 NEON 代码(Armv7-A内容,但区别不大)5.1 优化 NEON 汇编代码5.1.1 Cortex-A 处理器之间的 NEON 管道差异5.1.2 内存访问优化 Reference: NEON优…

大数据Flink(九十四):DML:TopN 子句

文章目录 DML:TopN 子句 DML:TopN 子句 TopN 定义(支持 Batch\Streaming):TopN 其实就是对应到离线数仓中的 row_number(),可以使用 row_number() 对某一个分组的数据进行排序 应用场景

APP或小程序突然打开显示连接网络失败,内容一片空白的原因是,SSL证书到期啦,续签即可

由于我们使用的是https,所以SSL证书到期了,通过https进入读取内容的APP或网站或小程序就会打开后连接网络失败,出现空白,这是因为我们申请的SSL证书到期了,因为我们申请的证书有效期有时是1个月或3个月,到期…

BI神器Power Query(26)-- 使用PQ实现表格多列转换(2/3)

实例需求:原始表格包含多列属性数据,现在需要将不同属性分列展示在不同的行中,att1、att3、att5为一组,att2、att3、att6为另一组,数据如下所示。 更新表格数据 原始数据表: Col1Col2Att1Att2Att3Att4Att5Att6AAADD…

【AI视野·今日NLP 自然语言处理论文速览 第四十二期】Wed, 27 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Wed, 27 Sep 2023 Totally 50 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Attention Satisfies: A Constraint-Satisfaction Lens on Factual Errors of Language Models Authors Mert …

Flutter开发之Package与Plugin

前言 在flutter中有包和插件两个概念,插件 (plugin) 是 package 的一种,全称是 plugin package,我们简称为 plugin,中文叫插件。包(Package)主要指对flutter相关功能的封装,类似于Android中的插件和iOS中的三方库。而插…

JVM机制理解与调优方案

作者:逍遥Sean 简介:一个主修Java的Web网站\游戏服务器后端开发者 主页:https://blog.csdn.net/Ureliable 觉得博主文章不错的话,可以三连支持一下~ 如有需要我的支持,请私信或评论留言! 前言 很多Java开发…

2023年9月随笔之摩托车驾考

1. 回头看 日更坚持了273天。 读《SQL学习指南(第3版)》更新完成 读《高性能MySQL(第4版)》持续更新 学信息系统项目管理师第4版系列持续更新 9月码字81307字,日均码字数2710字,累计码字451704字&…

Node18.x基础使用总结(二)

Node18.x基础使用总结 1、Node.js模块化1.1、模块暴露数据1.2、引入模块 2、包管理工具2.1、npm2.2、npm的安装2.3、npm基本使用2.4、搜索包2.5、下载安装包2.6、生产环境与开发环境2.7、生产依赖与开发依赖2.8、全局安装2.9、修改windows执行策略2.10、安装包依赖2.11、安装指…

日期范围搜索

1.日期范围选择界面 <?xml version"1.0" encoding"utf-8"?> <ScrollViewandroid:layout_width"fill_parent"android:layout_height"fill_parent"xmlns:android"http://schemas.android.com/apk/res/android">…

桂院校园导航 静态项目 二次开发教程 1.2

Gitee代码仓库&#xff1a;桂院校园导航小程序 GitHub代码仓库&#xff1a;GLU-Campus-Guide 先 假装 大伙都成功安装了静态项目&#xff0c;并能在 微信开发者工具 和 手机 上正确运行。 接着就是 将项目 改成自己的学校。 代码里的注释我就不说明了&#xff0c;有提到 我…

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石①

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石① 第十九章 驱动程序基石①19.1 休眠与唤醒19.1.1 适用场景19.1.2 内核函数19.1.2.1 休眠函数19.1.2.2 唤醒函数 19.1.3 驱动框架19.1.4 编程19.1.4.1 驱动程序关键代码19.1.4.2 应用程序 19.1.5 上机实验19.1.6 使用环形缓…

89、Redis 的 value 所支持的数据类型(String、List、Set、Zset、Hash)---->Zset 相关命令

本次讲解要点&#xff1a; ** Set相关命令&#xff1a;是指value中的数据类型** 启动redis服务器&#xff1a; 打开小黑窗&#xff1a; C:\Users\JH>e: E:>cd E:\install\Redis6.0\Redis-x64-6.0.14\bin E:\install\Redis6.0\Redis-x64-6.0.14\bin>redis-server.exe …

【算法分析与设计】贪心算法(下)

目录 一、单源最短路径1.1 算法基本思想1.2 算法设计思想1.3 算法的正确性和计算复杂性1.4 归纳证明思路1.5 归纳步骤证明 二、最小生成树2.1 最小生成树性质2.1.1 生成树的性质2.1.2 生成树性质的应用 2.2 Prim算法2.2.1 正确性证明2.2.2 归纳基础2.2.3 归纳步骤2.3 Kruskal算…

【刷题笔记10.2】LeetCode: 罗马数字转整数

LeetCode: 罗马数字转整数 一、题目描述 二、分析 方法一&#xff1a; 将给定字符串s中的"IV", “IX”, “XL”, “XC”, “CD”, “CM” 全部替换为其他字符如&#xff1a;a, b, c, d, e, f 这种&#xff0c;然后就可以遍历累加了。 s s.replace("IV",…

python-切换镜像源和使用PyCharm进行第三方开源包安装

文章目录 前言python-切换镜像源和使用PyCharm进行第三方开源包安装1. 切换镜像源2. 使用PyCharm进行第三方开源包安装 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每…

7.3 调用函数

前言&#xff1a; 思维导图&#xff1a; 7.3.1 函数调用的形式 我的笔记&#xff1a; 函数调用的形式 在C语言中&#xff0c;调用函数是一种常见的操作&#xff0c;主要有以下几种调用方式&#xff1a; 1. 函数调用语句 此时&#xff0c;函数调用独立存在&#xff0c;作为…

ARINC825规范简介

ARINC825规范简介 机载CAN网络通用标准 ARINC825规范全称为机载CAN网络通用标准&#xff08;The General Standardization of CAN for Airborne Use&#xff09;。顾名思义&#xff0c;ARINC825规范是建立在CAN物理网络基础上的高层规范。CAN网络使用共享的双绞电缆传输数据&…

接雨水问题

接雨水问题 问题背景 LeetCode 42. 接雨水 接雨水问题是一个经典的计算雨水滞留量的问题&#xff0c;通常使用柱状图来表示不同高度的柱子。在下雨的情况下&#xff0c;柱子之间的凹陷部分能够存储雨水&#xff0c;问题的目标是计算这些柱子所能接收的雨水总量。 相关知识 …