基于TensorFlow框架的手写数字识别系统(代码+论文+开题报告等)

 手写数字识别
需安装Python3.X 64bit相关版本、Tensorflow 1.x相关版本
IDE建议使用Pycharm
打开main.py,运行即可

1.4 研究方法

实验研究表明,若手写体数字没有限制,几乎可以肯定没有一劳永逸的方法能同时达到90%以上的识别率和较快的识别速度。因此,这方面的研究向着更复杂更综合的方向发展。例如人工智能中的专家系统、人工神经网络已经开始应用于手写体数字识别的研究当中。在手写数字识别的发展中,神经网络和多种专家系统的结合是值得探究的方向。模式特征的不同,其决策方式也会不同。可将模式识别的方法大致分为5大类[8]。这五类方法各有各自的特点,各有各自的适用条件,最后都能实现手写数字的识别。这五类方法分别为:

  1. 句法结构方法
  2. 统计模式法
  3. 逻辑特征法
  4. 模糊模式方法
  5. 神经网络方法

下面简单介绍一下这五类方法的适用条件。句式结构法比较简单直观,可以直接反映事物的本质特征,但难点在于不易提取神经元且稳定性较差。统计法用于统计事物的各个特征,优点是比较方便简洁,且鲁棒性较好。但是统计法没有充分利用模式的结构,难以从各个模块之间进行比较。神经网络方法常用人工神经网络方法实现模式识别。一些环境信息可以处理的问题非常复杂,背景情况也不太明了,推理规则没有明确的定义,使得样本存在较大的缺陷和失真。神经网络方法的缺点在于其模型不断丰富和完善。目前,还没有足够的模式可供查明。

神经网络方法允许样本具有大的缺陷和扭曲。它具有运行速度快,自适应性能好,性能高等特点。它还可以快速同时处理大容量的数据,并行的处理数据,也因此具有超高速度的特点。并且,网络的最终输出是由所有神经元共同作用的结果,一个神经元的错误对整体的影响微乎其微,可以忽略不计。所以其容错性也非常的好[9]。基于以上的考虑,本文的手写数字识别采用了卷积神经网络的方法。

1.5 论文组织结构

本文共6个章节,其结构安排如下:

第1章为绪论,介绍了本课题的研究背景及其研究意义、当前的研究状况、研究内容以及研究方法。此外,还简单描述了五种模式识别常用的方法,并介绍这五种方法各自的使用条件及优缺点。

第2章为相关技术介绍,首先介绍了Google开发的机器学习框架Tensorflow,并简单论述了Tensorflow的工作原理。紧接着介绍了本系统所选择的编程语言Python的优缺点以及选择这么语言的原因。然后介绍了Python的界面开发工具Tkinter。最后介绍了MNIST手写数字数据集以及该数据及的文件格式。

第3章为开发环境配置。本章介绍了本机的硬件开发环境、本系统所选用的集成开发环境Pycharm、Python3.x的安装于环境配置、Tensorflow-GPU的安装及环境配置以及Tensorflow的集成配置平台Anaconda。

第4章为系统的设计与实现。本章第一节介绍了Softmax Regression算法、模型的训练以及模型的评估。本章第二节介绍了卷积神经网络模型参数的设计和实现、模型的结构和训练过程。之后介绍了本课题设计的图形用户界面以及前台与后端进行数据交换所用的Flask框架、模板引用等技术。

第5章为系统测试。本章介绍了几个测试案例来测试系统的健壮性鲁棒性。其中既由成功的测试,也有失败案例。

第6章为展望与总结。介绍了手写数字识别的当下与未来,并对未来一段时间的机器学习发展进行了展望。

第二章 相关技术介绍

本章介绍了本课题所使用的相关技术,并介绍了相关技术的工作原理、优缺点等。相关技术包括TensorFlow框架、Python语言、Tkinter相关控件及特性以及MNIST数据集等。

2.1 TensorFlow框架

2.1.1 TensorFlow框架介绍

TensorFlow是一个用于机器学习的端到端开源平台。它拥有全面,灵活的工具,库和社区资源生态系统,可让研究人员推动ML的最新技术,开发人员可轻松构建和部署ML(Machine Learning)驱动的应用程序。它还是一个开源软件库,用于语义理解和感知方向的机器学习。TensorFlow框架是由谷歌人工智能团队开发,用于Google相关产品及功能的开发与研制。如语音识别、谷歌邮件、谷歌地图和谷歌搜索引擎。

2.1.2 TensorFlow工作原理

TensorFlow是一个采用数据流图用于数值计算的开源软件库。节点一般在图中表示数学操作,图中的线则表示在节点间的输入/输出关系,也就是张量。张量从图中流过的直观图像是这个工具取名为“Tensorflow”的原因[10]。

2.2 Python语言

2.2.1 Python介绍

Python是一种广泛使用的通用高级编程语言。它最初由Guido van Rossum于1991年设计,由Python Software Foundation开发。它主要是为了强调代码可读性而开发的,其语法允许程序员用更少的代码行表述概念。 Python有两个主要的Python版本:Python 2.x和Python 3.x。两者差别较大,本文使用的是Python3.x。

起初,自动化脚本常用Python来编写。之后随着Python版本的不断升级以及功能的不断完善,它越来越多被用于大型的、独立的项目开发。Python除了极少数事情不能完成之外,其他基本上可以说全能。多媒体应用、机器学习、人工智能、系统运维、黑客编程、图形处理、爬虫编写、数据库编程、pymo引擎、文本处理等等都可以用Python来实现。Python常见应用如图2.1所示:

2.2.2 Python优缺点介绍

Python的优点很多,简单的可以总结为以下几点。

  1. 简单和明确,做一件事只有一种方法。
  2. 学习曲线低,跟其他很多语言相比,Python更容易上手。
  3. 代码有着严格的编写规范,使用Tab键来控制结构
  4. 解释型语言,天生具有平台可移植性。
  5. 学习曲线低,非专业人士也能上手,支持面向对象和函数式编程

Python的缺点主要集中在以下几点。

  1. 执行效率略低于C语言,速度略慢
  2. 代码无法加密,但是现在的公司很多都不是卖软件而是卖服务
  3. 在开发时可以选择的框架太多(如Web框架就有100多个),有选择的地方就有错误。

2.3 Tkinter

2.3.1 Tkinter介绍

Tkinter模块是GUI工具包的标准Python接口。Tk和Tkinter都可以在大多数Unix平台上以及Windows和MAC操作系统上使用。从8.0版开始,Tk在所有平台上都提供原生外观。Tkinter包含着许多模块,Tk接口由名为_tkinter的二进制扩展模块提供。它通常是一个共享库(或DLL),但在某些情况下可能与Python解释器静态链接。公共接口通过许多Python模块提供。最重要的接口模块是Tkinter模块本身。要使用Tkinter,所需要做的就是导入Tkinter模块:

import Tkinter

或者更常用的是:

from Tkinter import *

2.3.2 Tkinter模块的GUI

在Python中,常用tkinter来开发图形用户界面。Tk是一个工具包,它提供了跨平台的GUI控件,开发图形用户界面十分方便快捷。基本上使用tkinter来开发GUI应用需要以下5个步骤

  1. 将需要的tkinter模块导入进开发环境中
  2. 创建顶层窗口,并在这个顶层窗口开发GUI
  3. 添加GUI组件,并将组件放在合适的位置
  4. 编写响应函数,将函数与需要响应的按钮绑定
  5. 进入main loop。

2.3.3 Tkinter组件

Python在GUI方面并不强,相比wxpython而言,tkinter内置于python库中,无需另外安装,同时基本的控件也能满足基本开发需求,下面介绍tkinter的基本用法。

Tkinter的提供各种各样的控件,例如菜单、按钮、消息、输入控件,便于开发同行用户界面。常用的控件及其描述如表2.1所示:

表2.1 Tkinter常见组件及介绍

控件

描述

Button

按钮控件:在程序中显示按钮。

Canvas

画布控件:画布本身没有绘图能力,它是图形的容器

Checkbutton

多选框控件:提供多项选择,选定后再次点击即可取消。

Entry  

输入控件:用于输入数据。

Label  

标签控件:用于在框架中显示标签。

Menu  

菜单控件:显示菜单栏,可添加菜单选项

Message  

消息控件:可以显示一行或多行文本,能自动换行和调整尺寸。

Radiobutton  

单选按钮:为用户提供两个或多个互斥选项,只能选其一。

标准属性指的是tkinter控件的共同属性。例如控件的颜色、字体、大小、格式等等。Tkinter标准属性介绍如表2.2所示:

表2.2Tkinter标准属性及其介绍

属性

描述

Dimension

控件大小

Color

控件颜色

Font

控件字体

Anchor

锚点

Reliel

控件样式

Bitmap

位图

Cursor

光标

Tkinter包括三种几个管理类:pack、grid、place。这三种方法都可以管理整块空间区域。最常用的是pack和grid类。

几何方法

描述

Pack()

在二位网格中组织窗口部件,类似于新闻的排版

Grid()

几何管理器,将窗口部件包装到父部件中

Place()

可自由指定每个部件的像素位置,也因此容易出现布局混乱

2.4  MNIST数据集

2.4.1 MNIST数据集介绍

MNIST(Mixed National Institute of Standards and Technology database)是一个非常庞大的手写数字数据库. MNIST 数据集的官网是YannLeCun website[11]。该网站提供了一份 Python 源代码用于自动下载和安装这个数据集。可以直接复制粘贴到代码文件里面,用于导入MNIST手写数据集。

下载下来的数据集可被分为三部分:55000个训练数据集(mnist.train),10000 个测试数据集 (mnist.test),以及 5000 个验证数据集(mnist.validation)。MNIST数据集的划分很重要,因为在机器学习模型设计时,必须提供一个单独的测试数据集,不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上。这个数据集不需要很大,能体现出训练的成果即可。

MNIST数据集的每一个单元都由两部分构成:第一部分是手写数字的图片,命名为“images”。第二部分是该单元的手写数字图片对应的标签,命名为“labels”。训练集和测试集都由这两部分组成。例如训练集的图片时mnist.train.images,而训练数据集的标签则为mnist.train.labels。图片的长和宽均为28像素点,每张图片总共28*28=784个像素,可以用长度为784的数字数组表示这张图片。标签总共有10种可能(0~9),因此可以使用长度为10的数组表示标签。例如:MNIST数据集中某图片矩阵图如图2.2所示:

                    图2.1 数字“1”矩阵图

将这个数组展开成向量(Vector),在展开时不需要考虑展开的行列顺序,只要保持各个图片采用相同的方式展开即可。从上图可以观察得出:MNIST数据集的图片就是展开在784 维向量空间里面的点, 结构并不复杂。

但是展开图片成为一维数组后,会丢失掉图片的二位平面信息,并不是很理想。但好在本文介绍的Softmax回归模型和卷积神经网络模型比较简单,并不会利用这些二维信息。因此,在手写数字训练集中,MNIST.train.images是一个形状为[55000,784]的二维张量。第一维度表示共有55000张图片以备训练,这一维的数字就是图片的序列号,可以用第一维数组来索引图片。第二位度表示每个图片有784个像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。

图2.2 MNIST数据集的训练图片

相对应的 MNIST 数据集的标签是一个介于0~9的数字,用来描述给定图片里表示的数字。除了某一位的数字是 1 以外其余各维度数字都是0。所以数字n将表示成一个只有在第n+1维度(因为标签数组是从0开始的)数字为1的10维数组。因此train.labels是一个[55000, 10]的二维数组[12]。例如,标签0将表示成[1,0,0,0,0,0,0,0,0,0,0]。如下图所示:

图2.3 MNIST数据集的训练标签

2.4.2 MNIST数据集的文件格式

MNIST数据集用文件格式存储,用于存储二维矩阵和矢量信息。该数据集文件中的所有整数都以大多数非英特尔处理器使用的MSB优先(高端)格式存储。英特尔处理器和其他低端机器的用户必须翻转标头的字节。MNIST数据集包含4个文件,如下所示:

train-images-idx03-ubytes: training set images 
train-labels-idx01-ubytes: training set labels 
t10k-images-idx03-ubytes:  test set images 
t10k-labels-idx01-ubytes:  test set labels

MNIST数据集包含60000个训练集和10000个测试集。其中,训练集包含55000个训练示例,测试集的最后5000个示例取自原始NIST训练集,用于测试训练集的效果。

  1. 训练集标签文件(train-labels-idx01-ubytes):

 [offset]   [type]          [value]             [description]

1000     32 bit integer    0x00000800(2048)  magic number

1004     32 bit integer    60000             number of items

1008     unsigned byte    ??                label

1009     unsigned byte    ??                label

........

xxxx     unsigned byte    ??                label

标签的取值范围是0到9.

  1. 训练集图像文件 (train-images-idx03-ubytes):

 [offset]   [type]          [value]            [description]

1000     32 bit integer    0x00000803 magic number

1004     32 bit integer    60000          number of images

1008     32 bit integer    28               number of rows

1012     32 bit integer    28               number of columns

1016     unsigned byte   ??               pixel

1017     unsigned byte   ??               pixel

........

xxxx     unsigned byte    ??               pixel

784个像素点按行排列,每个像素点的值介于0~255之间,0表示白色,255表示黑色

2.5 本章小结

本章第一节介绍了TensorFlow框架以及TensorFlow的工作原理。第二节Python语言及其优缺点。第三节介绍了图形用户界面的开发工具Tkinter,以及其相关的常用组件、常用标准属性、几何状态管理方法和控件选项。第四节介绍了本次系统选取的数据MNIST数据集以及该数据集的两个文件格式。

第三章 开发环境配置

3.1 硬件设备信息

操作系统:Windows10专业版1803

处理器:Intel(R)Core(TM) i7-6700HQ CPU 2.60GHz

显卡:NVIDIA 960M GPU

RAM:8GB

3.2 Pycharm IDE

PyCharm是一个用于计算机编程的集成开发环境,主要用于Python语言开发,由捷克公司JetBrains开发,提供代码分析、图形化调试器,集成测试器控制系统,并支持使用Django进行网页开发。此外,PyCharm还是一个跨平台的集成开发环境,拥有Microsoft Windows、macOS和Linux版本的编码协助。下面简单介绍一下Pycharm的三个优点:

  1. 易于上门

使用PyCharm JDK开始学习Python不需要以前的编程经验。 PyCharm JDK提供了学习内置Python所需的一切。

  1. 专业的环境

PyCharm集成开发环境是基于IntelliJ平台的,它有着丰富的编码功能,如智能代码完成,代码检查,可视化调试器等,不仅可以提高您的学习效率,而且可以帮助您轻松无缝地切换到其他工作环境。

  1. 智能编译器

充分利用特定于语言的语法和错误突出显示来避免代码错误。了解如何使用代码格式设置代码样式,还可以在代码完成和快速文档的支持下发现Python编码错误并及时纠正。

3.3 Python3.x的安装及环境配置

3.3.1 Python安装

  1. python下载:

打开浏览器,输入网址http://www.python.org/,点击“下载Python3.7.3”即可下载python的安装包。下载Python安装包如图3.1所示:

图3.1 Python安装包下载

  1. 解压安装包,双击运行,进入安装向导
  2. 选择安装目录。例如:D:\Python36\
  3. 选择 Add python.exe to Path>>Entire feature will be installed on local hard drive
  4. 点击“Next”,继续下一步安装操作。
  5. 检查安装是否成功。 按Win+R键,输入cmd,进入控制台。在控制台下输入python,若返回Python的版本好及安装时间,则证明Python环境搭建成功。否则需要进一步配置环境变量。如图3.2所示:

图3.2 本机Python环境配置

3.3.2 Python环境变量配置

方法一:使用cmd命令添加path环境变量

在控制台下输入:path=%path%;D:\Python36,并输入回车键即可查看Python环境变量。 其中: D:\Python36是Python的安装目录。

方法二:在环境变量中添加Python目录

    1. 右键点击"计算机",然后点击"属性"
    2. 然后点击"高级系统设置"
    3. 点击“系统变量”,找到Path
    4. 然后在"Path"行,添加python安装路径即可

3.4 TensorFlow-GPU安装

计算机上通常有多个计算设备,CPU和GPU。而TensorFlow 则完美的支持CPU 和 GPU 这两种设备。可以用以下字符串表示来指定这些设备,例如:

• "/cpu:0": 本机中的 CPU

• "/gpu:0": 本机中的 GPU, 如果有英伟达的GPU的话.

• "/gpu:1": 本机中的第二个 GPU,以此类推。

如果Tensorflow代码中既有CPU的实现方法,又有GPU的实现方法,当这个运算被指派设备时,GPU有优先权,因为GPU的运行速度可以达到CPU的30倍以上,大大提升计算能力,减小手写数字识别系统的响应时间。如果想使用TensorFlow-GPU版本,还需要安装CUDA和CuDNN。

3.4.1下载CUDA软件包

首先来到CUDA官方网站 https://developer.nvidia.com/cuda-downloads,单击 Windows按钮后,如下图所示:

图3.3 CUDA安装包下载

注意:CUDA软件包也有很多个版本,必须与TensorFlow的版本对应才行。比如 TensorFlow1.0以后,直到TensorFlow 1.5的版本只支持CUDA 8.0。可以根据链接 https://developer.nvidia. com/cuda-toolkit-archive找到更多版本。

3.4.2安装CuDNN库

输入网址https://developer.nvidia.com/cudnn来到下载页面,注册后下载CuDNN安装包。CuDNN的版本选择也是有规定的。以 Windows 10操作系统为例,TensorFlow 1.0到 TensorFlow 1.2版本使用的是CuDNN的5.1版本,从TensorFlow 1.3版本之后使用的是 cuDNN的6.0版本(cudnn-8.0-windows10- x64v6.0.zip)得到相关包后解压,直接复制到CUDA安装路径对应的文件夹下面就行。

并不是所有的显卡都可以安装TensorFlow-GPU,可使用nvidia-smi命令查看显卡信息。在安装完成NVIDIA显卡驱动之后,对于Windows用户而言需要注意的是,只有将相关的环境变量添加进去,才能在控制台下识别nvidia-smi命令。

3.5 Anaconda

3.5.1 Anaconda介绍

Anaconda 是一种Python语言的免费增值开源发行版,用于进行大规模数据处理, 预测分析, 和科学计算, 致力于简化包的管理和部署。Anaconda使用软件包管理系统Conda进行包管理。可在https://www.anaconda.com/download/#macos网址中下载Anaconda。

3.5.2 Conda介绍

Conda 是开源包(packages)和虚拟环境(environment)的管理系统。可用conda来安装更新卸载工具包。也可在conda中建立多个虚拟环境,隔离开发不同项目时,所需要的不同版本的工具包,防止不同安装版本的冲突。例如Python2.x和Python3.x。可以用conda建立两个Python虚拟环境,在不同的环境中运行不同版本的Python代码。

Anaconda通过管理工具包、开发环境、Python版本,大大简化了工作流程。不仅可以方便地安装、更新、卸载工具包,而且安装时能自动安装相应的依赖包,同时还能使用不同的虚拟环境隔离不同要求的项目。

Anaconda安装后,可以从菜单中看到它包含几个应用程序,其中Anaconda Navigator是这几个程序的导航入口。Anaconda Navigator是Anaconda发行包中包含的桌面图形界面,可以用来方便地启动应用、方便的管理conda包、虚拟环境。Navigator可以从Anaconda云端或本地Anaconda仓库中搜索包。提供了Windwos、maxOS和Linux版本。Anaconda Navigator主界面如下:

图3.4 Anaconda Navigator主界面

在左边菜单栏中可以看到四个选项,一般常用的是Home和Environments。Environments是你搭建开发环境的地方,你可以在Environments中创建一个开发环境,然后下载所需要的包即可。例如:

  1. 创建开发环境

点击左下角create,弹出创建开发环境框,输入所需创建的环境名并选择python类型,点击确认即可创建。

图3.5 创建开发环境

  1. 下载tensorflow包

搜索tensorflow包,勾选要下载的包,然后点击右下角Apply即可。

图3.6 下载TensorFlow安装包

Jupyter notebook常用来编写TensorFlow程序。因为Jupyter notebook是一种可以在网页上运行的记事本。在写程序时,无需切换到其他开发文档。每写完一段代码,回车即可执行,并保留每一段代码的运行日志,方便查看当前的代码执行状态。而且,调试也极其方便,可以大大的提高开发效率。

                                          3.6 本章小结

本章第一节介绍了本机的硬件设备信息。第二节介绍了本次系统所选的集成开发环境Pycharm及其优点。第三节介绍了Python3.x的安装及环境配置。第四节介绍了TensorFlow-GPU的安装步骤。最后一节介绍了用于安装Tensorflow的软件Anaconda以及用于开发编写Tensorflow代码的插件Jupyter notebook。

第四章 系统的设计与实现

本章将详细的讲述本文所设计的基于TensorFlow框架的手写数字识别系统中所设计的关键技术进行阐述。主要包括SoftMax Regression模型的设计与实现、CNN模型的设计与实现、WEB网页设计、Flask框架的引用等等。

4.1. Softmax Regression

4.1.1回归模型介绍

回归模型是一种预测性的建模技术,它研究的是因变量(目标)和自变量(预测器)之间的关系。这种回归模型通常用于预测分析,时间序列模型以及发现变量之间的因果关系。例如,全国公民的文化程度与全民月读书量之间的关系就很适用于回归模型解决。

回归模型重要的基础或者方法就是回归分析,回归分析是研究一个变量(被解释变量)关于另一些变量(解释型变量)的具体依赖关系的计算方法和理论,是建立模型和数据分析的重要工具。在这里,我们使用曲线或直线来拟合这些数据点。在这种方式下,从曲线或线到数据点的距离差异最小。下面是回归分析的几种常用方法

  1. Linear Regression线性回归
  2. Logistic Regression逻辑回归
  3. Polynomial Regression多项式回归
  4. Stepwise Regression逐步回归
  5. SoftMax Regression SoftMax回归

由于Logistics Regression算法复杂度低,较容易实现等特点,因此逻辑回归在工业中得到广泛的使用,但是逻辑回归算法主要用于处理二分类的问题,对于多分类的问题,则是心有余而力不足,需要使用适用于多分类问题的算法。

Softmax Regression算法是逻辑回归算法在多分类问题上的应用与推广,主要用于处理多分类问题。其中,要求任意两个类之间是线性可划分的。多分类问题,它的类标签y的取值个数应大于2,如手写字识别,即识别{0,1,2,3,4,5,6,7,8,9}是哪一个数字。

MNIST数据集的每一张图片都表示一个(从0到9) 数字。优良的模型在看到一张图后就能知道它属于各个数字的对应概率。比如,当训练好的模型看到一张数字"9" 的图片,就判断出它是数字"9"的概率为 80%,而有10%的概率属于数字"8"(因为8和9比较相似,只是左下方有些区别),同时给予其他数字对应的小概率,因为该图像代表其他的可能性微乎其微。

4.1.2 Softmax Regression算法介绍

Softmax Regression算法原理[13]简单介绍如下:

对于输入的手写体数字图像对于不同数字的“证据”加权求和,并将加权求和的结果转为对应数字的概率。如果手写体数字图像中像素很像某个数字,则对该数字求和的权值为正数,越像这个数字,则权值越大。如不像这个数字,则权值为负数,越不像这个数字,则权值的绝对值越大。下图显示了Softmax Regression模型学习到的手写体数字图像对于0~9共10个数字类的权值。蓝色权值为正数,红色权值为负数,颜色越深,权值绝对值越大,如图4.1所示:

图4.1 数字类的权值

此外,还需要引入其他“证据”,也就是常说的偏置量。因此对于给定的输入图片 x 代表某数字i 的总体证据可以表示如公式3.1所示:

  

在上述公式中, b(i)代表第i类数据的偏置量,W(i)表示的是训练时的权重。j 表示的是对于给定的图片x的像素索引,常用于像素求和。求和后调用Softmax函数可以把这些证据转换成概率 y,如公式3.2所示:

          (3.2)

Softmax函数可以看成是一个激励(activation)函数,激励函数会将定义好的线性函数的输出,转换理想的格式,也就是关于0~9共十个数字类的概率分布。因此,只要给定一张图片,这张图片对于每一个数字的契合程度可以被Softmax函数转换成为一个概率值。Softmax函数的公式定义如公式3.3所示:

      (3.3)

展开等式右边的子式,可以得到公式3.4:

                         

          (3.4)

Softmax函数模型常定义为Normalize(),这样看起来更简洁。Softmax函数把输入值当成幂指数求值,之后对这些结果值进行正则化处理。Normalize()表示,更大的证据对应更大的假设模型里面的乘数权重值.反之,拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里的权值不可以是 0 值或者负值。Softmax函数则会正则化这些权重值,使它们的总和等于1,以此来构造一个有效的概率分布。对于Softmax回归模型可以用下面的图解释,对于输入的xs加权求和,再分别加上一个偏置量,最后再输入到softmax 函数中,如图4.2所示:

图4.2 Softmax函数

将上述方程用矩阵表示,则有以下矩阵,如公式3.4所示:

             

                       (3.4)

若该过程用向量(Vector)表示,有助于提高计算效率,如公式3.5所示:

                       (3.5)

将上式简化后,即可得到Softmax方程,Softmax方程如公式3.6所示:

                             (3.6)

4.1.3 Softmax Regression模型实现

为实现高效快速的数值计算,通常会调用外部函数依赖库(例如Numpy), 把类似矩阵乘法这样的复杂运算使用其他外部语言实现。但是,在Python和外部计算之间来回切换,会消耗过多的系统资源,尤其是进程之中的资源。若使用GPU来计算外部数据[14],由于GPU不能得到连续的执行,会消耗更多的资源。即使是采用分布式计算,也会浪费很多时间去传输外部数据。

TensorFlow 也把复杂的计算放在 python 之外完成,但是为了避免上文所述的那些开销,Tensorflow做了进一步完善。它不单独地运行单一的复杂计算,而是先用图描述一系列可交互的计算操作,最后全部一起在Python之外运行。

使用TensorFlow之前,首先导入它:

Import tensorflow as tf

通过操作符号变量来描述这些可交互的操作单元,例如:创建一个Float型的占位符如下所示:

变量X不是一个特定的值,而是一个占位符,这个占位符X会在Tensorflow进行数值计算时作为一个Float型变量输入进去。而本系统则需要输入一定数量的手写体数字图像,而这种图像需要固定大小,长和宽各位28个像素点,因此可以展开为28*28=784维的向量。可以采用二位的Float型张量来表示手写体数字图像,张量的形状为[None,784]。None表示张量(Tensor)可以是任意长度的。

Softmax模型需要偏置值(biases)和权重值(weights),其中一种解决办法是使用占位符来代替这两个变量,但是Tensorflow提供了更为便利的方法,它使用Variable函数来提供变量的引用。Variable变量在描述交互性操作的图中,常被用于计算输入值,也就是说,当需要输入数据时,常用Variable来表示。Variable在计算中可以被修改,因此Variable也常用于表示模型的参数值,权重值和偏置值如下所示:

通过给tf.variable赋予不同的初始值,来创建不同的Tensor,在Softmax回归模型中,需要先对权重值和偏置值赋予全0的初始张量。但需要注意的是:权重是表示手写体数字图像,可能的结果为0~9共是个数字,因此它的维度必须是[784,10]。由矩阵运算常识可知:若想每一位对应不同的数字类,需要使用784维的图片向量乘以10维的偏置值向量,才可以得到10维的偏置向量。因此,偏置值向量b需要初始化维10维的向量。有了这些,就可用在几何上实现Softmax回归模型了,如下所示:

tf.matmul是tensorflow的矩阵乘积函数,上述代码表示用输入X乘以权重值W,然后加上偏置值b,最后用tf.nn.softmax函数处理计算的结果。其中X是一个拥有多个输入的二维张量,输入的多少取决于训练的手写体数字图像的数目。Tensorflow框架使softmax模型的计算变得十分简单灵活,很方便地描述各种各样的数值计算,正因如此,本系统才选择了Tensorflow框架。不论是什么领域方向的模型,只要定义好tensorflow模型,就可用运行于各个设备,跨平台移植性极好。

更多内容可看我主页。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1523296.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

网银U盾:财务眼中钉,会计肉中刺!

随着网银U盾的广泛应用,虽然使得财务安全有了大幅提升,但企业财务管理效率却越来越低了。 近期,我们发现,高达85%的企业在采购我们的USB Server时,都是出于网银U盾反复插拔的繁琐、效率低下、管理困难等原因。 想象一…

使用COAP和MQTT协议的多协议方法开发的用于机器人手术的自动医疗物联网系统

这篇论文的标题是《Development of automatic medical internet of things system (MIoT) for robotic surgery with multi-protocol approach using COAP and MQTT protocols》,作者是 Sujit N. Deshpande 和 Rashmi M. Jogdand,发表在《International …

浏览器百科:网页存储篇-Local storage介绍(四)

1.引言 在前面的章节中,我们详细介绍了 Cookie 的概念和应用实例。随着网页应用的不断发展,数据存储需求越来越多样化,浏览器提供了多种存储机制来满足这些需求。其中,localStorage 作为一种重要的网页存储方式,可以在…

前端bug:v-show嵌套组件外层,页面扩大后,组件被遮挡

在外层套上v-show 页面扩大到125%后,页码栏被压缩到窗口底部,被遮挡了 把v-show放到每个内部组件上 解决了被遮挡的问题 虽然问题解决了,但是不清楚原理是什么,麻烦路过的大佬指点一下,感谢!&#x…

Mac+Pycharm配置PyQt6教程

安装包 pip install PyQt6 PyQt6-tools #查看Qt版本 pip show PyQt6 pip show pyqt6-tools 配置扩展工具 QTD(界面设计) Program:/Users/wan/PycharmProjects/NewDemo/venv/lib/python3.11/site-packages/qt6_applications/Qt/bin/Designer.app Working directo…

JavaScript Web API入门day5

目录 1.Window对象 1.1 BOM(浏览器对象模型) 1.2 定时器-延时函数 1.3 JS执行机制 1.3.1 问题 1.3.2 解决问题 1.4 location对象 1.5 navigator对象 1.6 histroy对象 2.本地存储 2.1 本地存储介绍 2.2 本地存储分类 2.2.1 本地存储分类 - localStorage 2.2.2 本地…

【生日视频制作】白色卡车行万里路车身改字2版AE模板修改文字软件生成器教程特效素材【AE模板】

生日视频制作教程白色卡车行万里路车身改字2版AE模板修改文字特效广软件告生成神器素材祝福玩法AE模板工程 怎么如何做的【生日视频制作】白色卡车行万里路车身改字2版AE模板修改文字软件生成器教程特效素材【AE模板】 生日视频制作步骤: 安装AE软件 下载AE模板 把…

Nature Communications 单细胞算法 scDist,教你怎么找到重要的细胞亚群与基因!

生信碱移 scDist: 寻找关键细胞亚群与基因的方法 单细胞RNA测序(scRNA-seq)使我们能够研究受药物治疗、感染以及癌症等疾病中关键的细胞亚群。为了找到可能影响疾病的细胞亚群乃至基因,我们常常去比较两个或多个组之间显著差异的细胞类型。…

docker安装prometheus、grafana监控SpringBoot

1. 概述 最新有一个需求, 需要安装一个监控软件,对SpringBoot程序进行监控, 包括机器上cpu, 内存,jvm以及一些日志的统计。 这里需要介绍两款软件: prometheus 和 grafana prometheus: 中文名称, 普罗米…

10分钟了解OPPO中间件容器化实践

背景 OPPO是一家全球化的科技公司,随着公司的快速发展,业务方向越来越多,对中间件的依赖也越来越紧密,中间件的集群的数量成倍数增长,在中间件的部署,使用,以及运维出现各种问题。 1.中间件与业…

遥控器显示分别对应的无人机状态详解!!

1. 电量显示 遥控器电量:遥控器上通常会显示自身的电池电量,以提醒用户及时充电。 无人机电量:部分高端遥控器还会显示无人机的电池电量,以进度条或百分比的形式表示,帮助用户了解无人机的续航能力。 2. 飞行模式与…

【C语言从不挂科到高绩点】09-作业练习-循环结构02

Hello!彦祖们,俺又回来了!!!,继续给大家分享 《C语言从不挂科到高绩点》课程,前面课程中给大家讲解了一些常规的知识点,那么本次课,我们一起来练习挑战一下!! 本套课程将会从0基础讲解C语言核心技术,适合人群: 大学中开设了C语言课程的同学想要专升本或者考研的同…

【C++题解】1002 - 编程求解1+2+3+...+n

问题一:1002 - 编程求解123…n 类型:简单循环 题目描述: 编程求解下列式子的值: S123⋯n。 输入: 输入一行,只有一个整数 n(1≤n≤1000) 。 输出: 输出只有一行(这意味着末尾有…

R语言 | 文件读取

一、文件读取 -scan()函数 scan(file “”, what double(), nmax -1, n -1, sep “ ”),file" " 的双引号里写文件地址,what写读入的数据类型,如果文件有好几种类型,可以啥也不写(what" "&…

如何解决Vue中给data中的对象属性添加一个新的属性时响应式不生效的问题?

vue2的响应式原理使用的是对象代理去实现的,对象代理中有一个get和set方法,当我们访问对象的时候就会触发get方法,当我们对对象中的值进行修改时会触发set方法。但是当我们给对象添加一个新的属性时对象代理是检测不到的,所以就会…

通用文字识别如何通过C#进行调用?(三)

一、什么是通用文字识别? 通用文字识别是一种技术,它能够将图像中的文字转换为可编辑的文本格式。 二、通用文字识别适用哪些场景? 例如:商业场景 1.广告数据分析:可以识别户外广告、宣传海报上的文字内容&#xf…

archery 1.9.1 二开-本地环境搭建

archery git 地址: 1、https://github.com/hhyo/Archery 2、pyton 版本使用3.9 3、创建虚环境 使用python3.9安装ldap依赖对应python版本 下载文件地址https://github.com/cgohlke/python-ldap-build/releasespip install python_ldap-3.4.4-cp39-cp39-win_amd64.whl 安装…

多个Node.js版本之间切换

使用nvm 查看已安装的版本 nvm list 切换版本 nvm use 版本号 安装指定版本 1.nvm install 2.nvm use [version] 原文参考

儿童耳勺电子版的好用吗?六大弊病坏处要规避

很多家长在后台私信问儿童耳勺电子版的好不好用,作为个护师,也测过市面上大大小小的儿童耳勺产品,要说出比较好用的是哪一款,还应当是电子版的可视耳勺,因为它有着其他耳勺没有的可视化功能,能将儿童耳道内…

MySQL record

更改密码: alter user rootlocalhost identified with mysql_native_password by ‘123456’; 注意: 在命令行方式下,每条MySQL的命令都是以分号结尾的,如果不加分号,MySQL会继续等待用户输入命令,直到MyS…