tensorflow2.0手写数字识别_tensorflow手写数字识别

tensorflow2.0手写数字识别_tensorflow手写数字识别本节笔记作为Tensorflow的HelloWorld,用MNIST手写数字识别来探索Tensorflow。笔记的内容来自Tensorflow中文社区和黄文坚的《Tensorflow实战》,只作为自己复习总结。

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

本节笔记作为 Tensorflow 的 Hello World,用 MNIST 手写数字识别来探索 Tensorflow。笔记的内容来自 Tensorflow 中文社区和黄文坚的《Tensorflow 实战》,只作为自己复习总结。

环境:

  • Windows 10
  • Anaconda 4.3.0
  • Spyder

本节笔记主要采用 Softmax Regression 算法,构建一个没有隐层的神经网络来实现 MNIST 手写数字识别。

1. MNIST 数据集加载

MNIST 数据集可以从MNIST官网下载。也可以通过 Tensorflow 提供的 input_data.py进行载入。

由于上述方法下载数据集比较慢,我已经把下载好的数据集上传到CSDN资源中,可以直接下载。

将下载好的数据集放到目录C:/Users/Administrator/.spyder-py3/MNIST_data/下。目录可以根据自己的喜好变换,只是代码中随之改变即可。

通过运行Tensorflow 提供的代码加载数据集:

from tensorflow.examples.tutorials.mnist import input_data

# 获取数据
mnist = input_data.read_data_sets("C:/Users/Administrator/.spyder-py3/MNIST_data/", one_hot=True)

MNIST数据集包含55000样本的训练集,5000样本的验证集,10000样本的测试集。 input_data.py 已经将下载好的数据集解压、重构图片和标签数据来组成新的数据集对象。

图像是28像素x28像素大小的灰度图片。空白部分全部为0,有笔迹的地方根据颜色深浅有0~1的取值,因此,每个样本有28×28=784维的特征,相当于展开为1维。

这里写图片描述

所以,训练集的特征是一个 55000×784 的 Tensor,第一纬度是图片编号,第二维度是图像像素点编号。而训练集的 Label(图片代表的是0~9中哪个数)是一个 55000×10 的 Tensor,10是10个种类的意思,进行 one-hot 编码 即只有一个值为1,其余为0,如数字0,对于 label 为[1,0,0,0,0,0,0,0,0,0]。

这里写图片描述

这里写图片描述

2. Softmax Regression 算法

数字都是0~9之间的,一共有10个类别,当对图片进行预测时,Softmax Regression 会对每一种类别估算一个概率,并将概率最大的那个数字作为结果输出。

Softmax Regression 将可以判定为某类的特征相加,然后将这些特征转化为判定是这一个类的概率。我们对图片的所以像素求一个加权和。如某个像素的灰度值大代表很有可能是数字n,这个像素权重就很大,反之,这个权重很有可能为负值。

特征公式:

这里写图片描述

b i b_i bi 为偏置值,就是这个数据本身的一些倾向。

然后用 softmax 函数把这些特征转换成概率 y y y :

这里写图片描述

对所有特征计算 softmax,并进行标准化(所有类别输出的概率值和为1):

这里写图片描述

判定为第 i 类的概率为:

这里写图片描述

Softmax Regression 流程如下:

这里写图片描述

转换为矩阵乘法:

这里写图片描述

这里写图片描述

写成公式如下:

这里写图片描述

3.实现模型

import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)

首先载入 Tensorflow 库,并创建一个新的 InteractiveSession ,之后的运算默认在这个 session 中。

  • placeholder:输入数据的地方,None 代表不限条数的输入,每条是784维的向量
  • Variable:存储模型参数,持久化的

4.训练模型

我们定义一个 loss 函数来描述模型对问题的分类精度。 Loss 越小,模型越精确。这里采用交叉熵:

这里写图片描述
其中,y 是我们预测的概率分布, y’ 是实际的分布。

y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))

定义一个 placeholder 用于输入正确值,并计算交叉熵。

接着采用随机梯度下降法,步长为0.5进行训练。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

训练模型,让模型循环训练1000次,每次随机从训练集去100条样本,以提高收敛速度。

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  train_step.run({x: batch_xs, y_: batch_ys})

5.评估模型

我们通过判断实际值和预测值是否相同来评估模型,并计算准确率,准确率越高,分类越精确。

这里写图片描述

6.总结

实现的整个流程:

  1. 定义算法公式,也就是神经网络前向传播时的计算。
  2. 定义 loss ,选定优化器,并指定优化器优化 loss。
  3. 迭代地对数据进行训练。
  4. 在测试集或验证集上对准确率进行评测。

7.全部代码

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 获取数据
mnist = input_data.read_data_sets("C:/Users/Administrator/.spyder-py3/MNIST_data/", one_hot=True)

print('训练集信息:')
print(mnist.train.images.shape,mnist.train.labels.shape)
print('测试集信息:')
print(mnist.test.images.shape,mnist.test.labels.shape)
print('验证集信息:')
print(mnist.validation.images.shape,mnist.validation.labels.shape)

# 构建图
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 进行训练
tf.global_variables_initializer().run()

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  train_step.run({x: batch_xs, y_: batch_ys})

# 模型评估
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print('MNIST手写图片准确率:')
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/193869.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • linux下的删除文件命令,Linux下删除文件命令「建议收藏」

    linux下的删除文件命令,Linux下删除文件命令「建议收藏」在linux中删除文件与文件夹我们可以直接使用rm就可以删除了,彻底删除文件或文件夹我们可以使用shred命令来完成,接下来是小编为大家收集的Linux下删除文件命令,希望能帮到大家。Linux下删除文件命令linux删除目录很简单,很多人还是习惯用rmdir,不过一旦目录非空,就陷入深深的苦恼之中,现在使用rm-rf命令即可。直接rm就可以了,不过要加两个参数-rf即:rm-rf目录名字…

    2022年7月26日
    6
  • 模板方法模式例子「建议收藏」

    模板方法模式例子「建议收藏」原文地址:http://www.cnblogs.com/jenkinschan/p/5768760.html一、概述 模板方法模式在一个方法中定义一个算法的骨架,而将一些步骤延迟到子类中。模板方法使得子类可以在不改变算法结构的情况下,重新定义算法中的某些步骤。二、结构类图三、解决问题模板方法就是提供一个算法框架,框架里面的步骤有些是父类已经定好的,有些需要子类自己实现。相当于要去办一件事情,行动的流

    2025年6月9日
    3
  • Java BigDecimal详解

    Java BigDecimal详解1.引言       借用《EffactiveJava》这本书中的话,float和double类型的主要设计目标是为了科学计算和工程计算。他们执行二进制浮点运算,这是为了在广域数值范围上提供较为精确的快速近似计算而精心设计的。然而,它们没有提供完全精确的结果,所以不应该被用于要求精确结果的场合。但是,商业计算往往要求结果精确,这时候BigDecimal就派上大用场啦。 2.BigD

    2022年6月7日
    36
  • java出现中文乱码_JAVA中文显示乱码问题「建议收藏」

    java出现中文乱码_JAVA中文显示乱码问题「建议收藏」在基于JAVA的编程中,经常会碰到汉字显示乱码的问题,经一番查询现总结如下。在JSP中建议网页编码方式用GBK,这样会方便一些。这个问题是因为JAVA编码方式转换出现了问题,Java中默认的编码方式是UNICODE,而中国人通常使用的文件和DB都是基于GB2312或者BIG5等编码,故会出现此问题。我知道一定有很多朋友也会碰到这个问题,所以特就总结了一下,来拿出来让大家一起分享了。自己也做个备忘。…

    2022年7月8日
    18
  • 跳频介绍_跳频功能

    跳频介绍_跳频功能跳频是最常用的扩频方式之一,其工作原理是指收发双方传输信号的载波频率按照预定规律进行离散变化的通信方式,也就是说,通信中使用的载波频率受伪随机变化码的控制而随机跳变。从通信技术的实现方式来说,“跳频”是一种用码序列进行多频频移键控的通信方式,也是一种码控载频跳变的通信系统。从时域上来看,跳频信号是一个多频率的频移键控信号;从频域上来看,跳频信号的频谱是一个在很宽频带上以不等间隔随机跳变的。其中:跳

    2025年8月12日
    2
  • int是什么_int a[4][4]

    int是什么_int a[4][4]Int16意思是16位整数(16bitinteger),相当于short占2个字节-32768~32767Int32意思是32位整数(32bitinteger),相当于int占4个字节-2147483648~2147483647Int64意思是64位整数(64bitinterger),相当于longlong占8个字节…

    2022年8月15日
    6

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号