手写体识别Tensorflow实现

在这里插入图片描述

简介:本文先讲解了手写体识别中涉及到的知识,然后分步讲解了代码的详细思路,完成了手写体识别案例的讲解,希望能给大家带来帮助,也希望大家多多关注我。本文是基于TensorFlow1.14.0的环境下运行的

手写体识别Tensorflow实现

  • 1 MNIST数据集处理
  • 2 神经网络
  • 3 Softmax函数
    • 3.1 什么时候用softmax
    • 3.2 softmax的优越性
  • 4 代码实现分步讲解
    • 4.1 导包
    • 4.2 载入数据
    • 4.3 批次batch
    • 4.4 placeholder的定义
    • 4.5神经网路模型的构建
    • 4.6 损失函数
    • 4.7使用梯度下降
    • 4.8 初始化 variable
    • 4.9 预测结果
      • 4.9.1 tf.equal函数
      • 4.9.2 tf.argmax函数
    • 4.10计算准确率
      • 4.10.1 tf.cast数据类型转换
      • 4.10.2 tf.reduce_mean
    • 4.13 对输入进行处理
    • 4.12 使用Session进行训练
    • 4.13代码汇总
  • 致谢

1 MNIST数据集处理

数据集的网址如下:https://yann.lecun.com/exdb/mnist/
在这里插入图片描述他的数据集有训练集 测试集图片与标签四部分组成
被分为两部分 6万行训练数据集和1万行的测试数据集

在这里插入图片描述
每一张图片包含2828个像素,把他展开成一维向量,长度是284284 = 784,所以训练集是shape为[60000,784]的张量,第一个维度数字用来索引图片,第二个维度数字用来索引图片中的像素点
他的标签是介于0-9的数字,我们要把它转化为one - hot,也叫做独热 ,比如3 转化为 [0,0,1,0,0,0,0,0,0],他是几就让第几个数字为1.。所以labels将会被转化为一个shape为[60000,10]的矩阵

2 神经网络

根据第一节的内容我们可以设计一个简单的神经网络实现手写体识别,如果想提升准确率,可以在中间加入隐藏层。
在这里插入图片描述

3 Softmax函数

就用手写体识别这个举例子,比如说预测了某张图片的shape为[1,10]的可能是[15,3,1,0,2,4,5,1,1,0],我们希望将他转化为概率,且需要所有概率和为1,我们来看softmax的数学公式

在这里插入图片描述
这个zi就是对应[1,10]矩阵中的权重,zj这个分母部分是所有的和
这样子计算既满足了归一化的需求

3.1 什么时候用softmax

一般是用在神经网络的输出层,用于分类或者回归

3.2 softmax的优越性

  • 满足了人们对归一化的需求
  • 指数函数容易求偏导
  • 指数函数咋信息论和统计学中常用,可以联系这些,为神经网络的构建提供数学依据

4 代码实现分步讲解

4.1 导包

因为环境和版本等种种原因,他经常会报一些无关痛痒的小警告,所以我们要把这些警告屏蔽掉,然后导入TensorFlow等包

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

4.2 载入数据

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train)
print(y_train)

在这里插入图片描述

4.3 批次batch

这些图片不会一次性处理运算量太大了,比如设置为100,每次都会处理100张图片

batch_size = 100

还需要计算一共有多少个批次

n_batch = len(x_train) // batch_size
print(n_batch)

在这里插入图片描述

4.4 placeholder的定义

x = tf.compat.v1.placeholder(tf.float32,[None,784])
y = tf.compat.v1.placeholder(tf.float32,[None,10])

4.5神经网路模型的构建

Weight = tf.compat.v1.Variable(tf.zeros([784,10]))
bias = tf.compat.v1.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,Weight)+bias)

4.6 损失函数

loss = tf.reduce_mean(tf.square(y-prediction))

4.7使用梯度下降

如果您不会用,请阅读我的文章: 线性回归,在该文章中讲解了该函数的具体用法

train_step = tf.compat.v1.train.GradientDescentOptimizer(0.2).minimize(loss)

4.8 初始化 variable

init_option = tf.compat.v1.global_variables_initializer()

4.9 预测结果

4.9.1 tf.equal函数

他的作用是判断预测和真实是否一致

4.9.2 tf.argmax函数

因为我们计算的是某张图是那个数字的概率,所以需要把最大的拿出来当做是这个图的预测结果

最后我们的道德结果是一个由False和True组成的列表

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

4.10计算准确率

4.10.1 tf.cast数据类型转换

我们需要先把布尔类型的结果转化为浮点类型 1…0和0

4.10.2 tf.reduce_mean

他的作用是计算张量的平均值

accuracy_rate = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

4.13 对输入进行处理

# 将输入数据重塑为二维形式(原本图像数据是二维的,这里要展平为一维向量作为神经网络输入)
# 例如原来是 (60000, 28, 28) 变成 (60000, 784),60000是样本数量,784是28*28(图像像素数量)
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

4.12 使用Session进行训练

with tf.compat.v1.Session() as calculate:calculate.run(init_option)y_train = calculate.run(tf.one_hot(y_train, depth=10))# 将测试集标签进行独热编码,显式指定会话参数y_test = calculate.run(tf.one_hot(y_test, depth=10))for epoch in range(21):for batch in range(n_batch):# 计算当前batch的起始索引和结束索引start_index = batch * batch_sizeend_index = start_index + batch_size# 从训练数据集中提取当前batch的输入数据和标签数据batch_x = x_train[start_index:end_index]batch_y = y_train[start_index:end_index]# 将当前batch的数据喂入计算图进行训练calculate.run(train_step, feed_dict={x: batch_x, y: batch_y})# 在每个epoch结束后,在测试集上计算并打印当前的准确率acc = calculate.run(accuracy_rate, feed_dict={x: x_test, y: y_test})print("Epoch {}: Accuracy {}".format(epoch + 1, acc))

在这里插入图片描述

4.13代码汇总

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train)
print(y_train)
batch_size = 100
n_batch = len(x_train) // batch_size
print(n_batch)
x = tf.compat.v1.placeholder(tf.float32,[None,784])
y = tf.compat.v1.placeholder(tf.float32,[None,10])
Weight = tf.compat.v1.Variable(tf.zeros([784,10]))
bias = tf.compat.v1.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,Weight)+bias)
loss = tf.reduce_mean(tf.square(y-prediction))
train_step = tf.compat.v1.train.GradientDescentOptimizer(0.2).minimize(loss)
init_option = tf.compat.v1.global_variables_initializer()
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy_rate = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
# 将输入数据重塑为二维形式(原本图像数据是二维的,这里要展平为一维向量作为神经网络输入)
# 例如原来是 (60000, 28, 28) 变成 (60000, 784),60000是样本数量,784是28*28(图像像素数量)
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)
with tf.compat.v1.Session() as calculate:calculate.run(init_option)y_train = calculate.run(tf.one_hot(y_train, depth=10))# 将测试集标签进行独热编码,显式指定会话参数y_test = calculate.run(tf.one_hot(y_test, depth=10))for epoch in range(21):for batch in range(n_batch):# 计算当前batch的起始索引和结束索引start_index = batch * batch_sizeend_index = start_index + batch_size# 从训练数据集中提取当前batch的输入数据和标签数据batch_x = x_train[start_index:end_index]batch_y = y_train[start_index:end_index]# 将当前batch的数据喂入计算图进行训练calculate.run(train_step, feed_dict={x: batch_x, y: batch_y})# 在每个epoch结束后,在测试集上计算并打印当前的准确率acc = calculate.run(accuracy_rate, feed_dict={x: x_test, y: y_test})print("Epoch {}: Accuracy {}".format(epoch + 1, acc))

致谢

本文参考了一些博主的文章,博取了他们的长处,也结合了我的一些经验,对他们表达诚挚的感谢,使我对 Tensorflow在手写体识别的使用有更深入的了解,也推荐大家去阅读一下他们的文章。纸上学来终觉浅,明知此事要躬行:
Softmax函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/17791.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringBoot】公共字段自动填充

问题引入 JavaEE开发的时候,新增字段,修改字段大都会涉及到创建时间(createTime),更改时间(updateTime),创建人(craeteUser),更改人(updateUser),如果每次都要自己去setter(),会比较麻烦&#…

【项目开发】为什么文件名要小写?

未经许可,不得转载。 文章目录 一、可移植性二、易读性三、易用性四、便捷性一、可移植性 Linux 系统对文件名大小写敏感,而 Windows 和 Mac 系统则不敏感。这种差异可能导致跨平台的问题。 例如,以下四个文件名: computerComPutercomPuterCOMPOTer在 Linux 系统上,它们…

ssm127基于SSM的乡镇篮球队管理系统+jsp(论文+源码)_kaic

毕 业 设 计(论 文) 题目:乡镇篮球队管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本乡镇篮球队管理…

C#获取视频第一帧_腾讯云媒体处理获取视频第一帧

一、 使用步骤: 第一步、腾讯云开启万象 第二步、安装Tencent.QCloud.Cos.Sdk 包 第三步、修改 腾讯云配置 图片存储目录配置 第四步、执行获取图片并保存 二、封装代码 using System.Text; using System.Threading.Tasks;using COSXML.Model.CI; using COSXML.A…

【数据分享】2003-2022年各省土地利用面积统计数据

数据介绍 2003-2022年各省土地利用面积统计数据数据时间2003-2008、2013、2015-2017、2019、2022数据类型excel数据指标土地调查面积/万公顷农用地面积/万公顷园林面积/万公顷牧草地面积/万公顷建设用地面积/万公顷居民点及工矿用地/万公顷交通用地/万公顷水利设施用地/万公顷…

任务调度工具Spring Test

Spring Task 是Spring框架提供的任务调度工具,可以按照约定的时间自动执行某个代码逻辑。 作用:定时自动执行某段Java代码 应用场景: 信用卡每月还款提醒 银行贷款每月还款提醒 火车票售票系统处理未支付订单 入职纪念日为用户发送通知 一.…

20 轮转数组

20 轮转数组 20.1 轮转数组解决方案 class Solution { public:void rotate(vector<int>& nums, int k) {int n nums.size();k k % n; // 如果 k 大于数组长度&#xff0c;取模减少不必要的旋转// 第一步&#xff1a;反转整个数组reverse(nums.begin(), nums.end(…

字符串相关题解

目录 字母异位词 最长公共前缀 博主主页&#xff1a;东洛的克莱斯韦克-CSDN博客 字母异位词 49. 字母异位词分组 - 力扣&#xff08;LeetCode&#xff09; 这道题更像一道语法题&#xff0c;考察对容器的掌握情况。如果按题目要求去模拟&#xff0c;不仅要分析每个字符串&am…

【微软:多模态基础模型】(3)视觉生成

欢迎关注【youcans的AGI学习笔记】原创作品 【微软&#xff1a;多模态基础模型】&#xff08;1&#xff09;从专家到通用助手 【微软&#xff1a;多模态基础模型】&#xff08;2&#xff09;视觉理解 【微软&#xff1a;多模态基础模型】&#xff08;3&#xff09;视觉生成 【微…

CentOS8 启动错误,enter emergency mode ,开机直接进入紧急救援模式,报错 Failed to mount /home 解决方法

先看现场问题截图&#xff1a; 1.根据提示 按 ctrld 输入 root 密码&#xff0c;进入系统。 2. 在紧急模式下运行&#xff1a;journalctl -xe &#xff0c;查看相关日志&#xff0c;找到关键点&#xff1a; Failed to mount /home 3.接着执行修复命令&#xff1a; xfs_repa…

2024140读书笔记|《作家榜名著:生如夏花·泰戈尔经典诗选》——你从世界的生命的溪流浮泛而下,终于停泊在我的心头

2024140读书笔记|《作家榜名著&#xff1a;生如夏花泰戈尔经典诗选》——你从世界的生命的溪流浮泛而下&#xff0c;终于停泊在我的心头 《作家榜名著&#xff1a;生如夏花泰戈尔经典诗选》[印]泰戈尔&#xff0c;郑振铎译&#xff0c;泰戈尔的诗有的清丽&#xff0c;有的童真&…

lenovo联想ThinkBook 14 G5 ABP(21JE)原装出厂Windows11系统恢复镜像包下载

适用机型 &#xff1a;【21JE】 链接&#xff1a;https://pan.baidu.com/s/1FUjwN8ZeaQ9qr3kNalSkYg?pwdqasf 提取码&#xff1a;qasf 联想原装出厂系统自带所有驱动、出厂主题壁纸、系统属性联机支持标志、系统属性专属LOGO标志、Office办公软件、联想电脑管家、联想浏览…

MySQL 数据类型

数值类型 int类型 类型说明tinyint1字节&#xff0c;范围从-128到127&#xff08;有符号&#xff09;&#xff0c;0到255&#xff08;无符号&#xff09;smallint2字节&#xff0c;范围从-2^15到2^15-1&#xff08;有符号&#xff09;&#xff0c;0到2^16-1&#xff08;无符号…

【WPF】Prism学习(三)

Prism Commands 1.复合命令&#xff08;Composite Commanding&#xff09; 这段内容主要介绍了在应用程序中如何使用复合命令&#xff08;Composite Commands&#xff09;来实现多个视图模型&#xff08;ViewModels&#xff09;上的命令。以下是对这段内容的解释&#xff1a; …

用go语言后端开发速查

文章目录 一、发送请求和接收请求示例1.1 发送请求1.2 接收请求 二、发送form-data格式的数据示例 用go语言发送请求和接收请求的快速参考 一、发送请求和接收请求示例 1.1 发送请求 package mainimport ("bytes""encoding/json""fmt""ne…

SpringCloud Alibaba入门简介和Nacos服务注册和配置中心

前面已经把spring cloud相关的组件都一一学了个遍,现在有点小佩服自己…本来计划今天周末好好出去玩一圈,天气太热了,39了都,还是在办公室学习吧,进行下面的springCloud Alibaba 学习吧…不废话了赶快进入正体 1. SpringCloud Alibaba入门简介 1.1 why会出现SpringCloud alib…

如何让Excel公式中的参数实现动态引用

如果你想成为Excel函数高手&#xff0c;仅仅掌握VLOOKUP和Countif等函数是远远不够的&#xff0c;起码你得学会使用INDIRECT函数&#xff0c;熟练掌握INDIRECT函数能让你从一个初学者晋级为高手&#xff0c;学会它就好比孙悟空掌握了72般变化的基本功&#xff0c;你说厉不厉害。…

【流量分析】常见webshell流量分析

免责声明&#xff1a;本文仅作分享&#xff01; 对于常见的webshell工具&#xff0c;就要知攻善防&#xff1b;后门脚本的执行导致webshell的连接&#xff0c;对于默认的脚本要了解&#xff0c;才能更清晰&#xff0c;更方便应对。 &#xff08;这里仅针对部分后门代码进行流量…

车载诊断架构 --- 关于DTC的开始检测条件

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 所有人的看法和评价都是暂时的,只有自己的经历是伴随一生的,几乎所有的担忧和畏惧,都是来源于自己的想象,只有你真的去做了,才会发现有多快乐。…

掌握 Spring Boot 的最佳方法 – 学习路线图

在企业界&#xff0c;人们说“Java 永垂不朽&#xff01;”。但为什么呢&#xff1f;Java 仍然是开发企业应用程序的主要平台之一。大型公司使用企业应用程序来赚钱。这些应用程序具有高可靠性要求和庞大的代码库。根据Java开发人员生产力报告&#xff0c;62% 的受访开发人员使…