基于keras的手写数字识别_数字识别

基于keras的手写数字识别_数字识别一、概述手写数字识别通常作为第一个深度学习在计算机视觉方面应用的示例,Mnist数据集在这当中也被广泛采用,可用于进行训练及模型性能测试;模型的输入为:32*32的手写字体图片,这些手写字体包含0~9数字,也就是相当于10个类别的图片模型的输出:分类结果,0~9之间的一个数下面通过多层感知器模型以及卷积神经网络的方式进行实现二、基于多层感知器的手写数字识别多层感知器的模型如下…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

一、概述

  • 手写数字识别通常作为第一个深度学习在计算机视觉方面应用的示例,Mnist数据集在这当中也被广泛采用,可用于进行训练及模型性能测试;
  • 模型的输入: 32*32的手写字体图片,这些手写字体包含0~9数字,也就是相当于10个类别的图片
  • 模型的输出: 分类结果,0~9之间的一个数
  • 下面通过多层感知器模型以及卷积神经网络的方式进行实现

二、基于多层感知器的手写数字识别

  • 多层感知器的模型如下,其具有一层影藏层:
784个神经元 784个神经元 10个神经元
输入层 影藏层 输出层
  • Mnist数据集此前可通过mnist.load_data()进行下载,但网址打不开,因此通过其他方式将数据集下载到本地,并在本地进行读取,数据集下载链接为:链接: https://pan.baidu.com/s/1ZlktkjqEGEJ0aZGQBQuqXg 提取码: br96
  • 改编后的数据读取方式如下:
import numpy as np def loadData(path="mnist.npz"): f = np.load(path) x_train, y_train = f['x_train'], f['y_train'] x_test, y_test = f['x_test'], f['y_test'] f.close() return (x_train, y_train), (x_test, y_test) # 从Keras导入Mnist数据集 (x_train, y_train), (x_validation, y_validation) = loadData() 
  • 完整的实现代码如下:
import matplotlib.pyplot as plt import numpy as np from keras.models import Sequential from keras.layers import Dense from keras.utils import np_utils def loadData(path="mnist.npz"): f = np.load(path) x_train, y_train = f['x_train'], f['y_train'] x_test, y_test = f['x_test'], f['y_test'] f.close() return (x_train, y_train), (x_test, y_test) # 从Keras导入Mnist数据集 (x_train, y_train), (x_validation, y_validation) = loadData() # 显示4张手写数字图片 plt.subplot(221) plt.imshow(x_train[0], cmap=plt.get_cmap('gray')) plt.subplot(222) plt.imshow(x_train[1], cmap=plt.get_cmap('gray')) plt.subplot(223) plt.imshow(x_train[2], cmap=plt.get_cmap('gray')) plt.subplot(224) plt.imshow(x_train[3], cmap=plt.get_cmap('gray')) plt.show() # 设定随机种子 seed = 7 np.random.seed(seed) num_pixels = x_train.shape[1] * x_train.shape[2] print(num_pixels) x_train = x_train.reshape(x_train.shape[0], num_pixels).astype('float32') x_validation = x_validation.reshape(x_validation.shape[0], num_pixels).astype('float32') # 格式化数据到0~1 x_train = x_train/255 x_validation = x_validation/255 # 进行one-hot编码 y_train = np_utils.to_categorical(y_train) y_validation = np_utils.to_categorical(y_validation) num_classes = y_validation.shape[1] print(num_classes) # 定义基准MLP模型 def create_model(): model = Sequential() model.add(Dense(units=num_pixels, input_dim= num_pixels,kernel_initializer='normal', activation='relu')) model.add(Dense(units=num_classes, kernel_initializer='normal', activation='softmax')) # 编译模型 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) return model model = create_model() model.fit(x_train, y_train, epochs=10, batch_size=200) score = model.evaluate(x_validation, y_validation) print('MLP: %.2f%%' % (score[1]*100)) 
  • 程序运行结果如下
784 10 Epoch 1/10 200/60000 [..............................] - ETA: 4:32 - loss: 2.3038 - acc: 0.1100 600/60000 [..............................] - ETA: 1:37 - loss: 2.0529 - acc: 0.3283 1000/60000 [..............................] - ETA: 1:02 - loss: 1.8041 - acc: 0.4710 ... 9472/10000 [===========================>..] - ETA: 0s 10000/10000 [==============================] - 1s 112us/step MLP: 98.07% 

三、基于卷积神经网络的手写数字识别

  • 构建的卷积神经网络结构如下:
1 x 28 x 28个输入 32maps, 5 x 5 2 x 2 20% 128个 10个
输入层 卷积层 池化层 Dropout层 Flatten层 全连接层 输出层

Flatten层: Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡,举例如下

input size —->> output size
32 x 32 x 3 Flatten–> 3072
  • 完整的实现代码如下:
import numpy as np from keras.models import Sequential from keras.layers import Dense from keras.layers import Dropout from keras.layers import Flatten from keras.layers.convolutional import Conv2D from keras.layers.convolutional import MaxPooling2D from keras.utils import np_utils from keras import backend backend.set_image_data_format('channels_first') def loadData(path="mnist.npz"): f = np.load(path) x_train, y_train = f['x_train'], f['y_train'] x_test, y_test = f['x_test'], f['y_test'] f.close() return (x_train, y_train), (x_test, y_test) # 从Keras导入Mnist数据集 (x_train, y_train), (x_validation, y_validation) = loadData() # 设定随机种子 seed = 7 np.random.seed(seed) x_train = x_train.reshape(x_train.shape[0], 1, 28, 28).astype('float32') x_validation = x_validation.reshape(x_validation.shape[0], 1, 28, 28).astype('float32') # 格式化数据到0~1 x_train = x_train/255 x_validation = x_validation/255 # 进行one-hot编码 y_train = np_utils.to_categorical(y_train) y_validation = np_utils.to_categorical(y_validation) # 定义模型 def create_model(): model = Sequential() model.add(Conv2D(32, (5, 5), input_shape=(1, 28, 28), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.2)) model.add(Flatten()) model.add(Dense(units=128, activation='relu')) model.add(Dense(units=10, activation='softmax')) # 编译模型 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) return model model = create_model() model.fit(x_train, y_train, epochs=10, batch_size=200, verbose=2) score = model.evaluate(x_validation, y_validation, verbose=0) print('CNN_Small: %.2f%%' % (score[1]*100)) 
  • 运行结果如下(明显感觉到运行时间较长):
Epoch 1/10 - 165s - loss: 0.2226 - acc: 0.9367 Epoch 2/10 - 163s - loss: 0.0713 - acc: 0.9785 Epoch 3/10 - 165s - loss: 0.0512 - acc: 0.9841 Epoch 4/10 - 165s - loss: 0.0391 - acc: 0.9880 Epoch 5/10 - 166s - loss: 0.0325 - acc: 0.9900 Epoch 6/10 - 162s - loss: 0.0268 - acc: 0.9917 Epoch 7/10 - 164s - loss: 0.0221 - acc: 0.9928 Epoch 8/10 - 161s - loss: 0.0190 - acc: 0.9943 Epoch 9/10 - 162s - loss: 0.0156 - acc: 0.9950 Epoch 10/10 - 162s - loss: 0.0143 - acc: 0.9959 CNN_Small: 98.87% 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/194000.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 软件测试的基本理论知识(软件测试面试基础知识)

    01软件研发流程1.软件产品软件产品是指向用户提供的计算机软件、信息系统或设备中嵌入的软件或在提供计算机信息系统集成、应用服务等技术服务时提供的计算机软件。2.软件工程软件工程,英文名SoftwareEngineering,是一门研究用工程化方法构建和维护有效的、实用的和高质量的软件的学科。“软件工程是开发、运行、维护和修复软件的系统方法。”这个定义相当概括,它主要强调软件工程是系统方法而不是某种…

    2022年4月18日
    47
  • 大数据舆情监测与分析_大数据分析系统架构

    大数据舆情监测与分析_大数据分析系统架构前言互联网的飞速发展促进了很多新媒体的发展,不论是知名的大V,明星还是围观群众都可以通过手机在微博,朋友圈或者点评网站上发表状态,分享自己的所见所想,使得“人人都有了麦克风”。不论是热点新闻还是娱乐八卦,传播速度远超我们的想象。可以在短短数分钟内,有数万计转发,数百万的阅读。如此海量的信息可以得到爆炸式的传播,如何能够实时的把握民情并作出对应的处理对很多企业来说都是至关重要的。大数据时代,除了…

    2022年9月20日
    3
  • Matlab矩阵复制扩充

    考虑这个问题:定义一个简单的行向量a    如何复制10行呢?即:    同理,对于一个列向量,如何复制10列呢?  关键函数1:repmat(A,m,n):将向量/矩阵在垂直方向复制m次,在水平方向复制n次。      再举一个例子,对于a=[12;34]:         垂直方向复制3次,水平方向复制2次,

    2022年4月8日
    61
  • rj45 千兆接口定义_网线的RJ45接口的针脚定义「建议收藏」

    我们生活中常用的网线接头类型分为两类:用于连接到网络中的终端设备的DTE类型,如连接到PC机的网卡的网线属于DTE型。还有用于网络设备间连接的DCE类型,如路由器连接到交换机的线或交换机连接到交换机的线均属于DCE型。DTE我们称做“数据终端设备”,这里的终端是一个广义的概念,PC也可以是终端(一般广域网常用DTE设备有路由器、终端主机)。DCE我们称做“数据通信设备”,如MODEM,连接DTE设…

    2022年4月10日
    562
  • windows连接Ubuntu16.10中winscp连接被拒绝「建议收藏」

    windows连接Ubuntu16.10中winscp连接被拒绝「建议收藏」这些天在玩Linux上的一些东西,物理机装了Linux,虚拟机也装了。但是很尴尬,完全从Windows上迁移到Linux上还是需要时间的,比如说今天,虚拟机上就碰到了问题。博主想在Windows上装一个winscp。winscp是一款文件传输工具,可以用来做不同系统之间的文件传输。 因为某些需要,博主的虚拟机网卡设置的是host-only模式,这种模式有一个缺点,也应该不算是缺点,在这种模式

    2025年12月14日
    4
  • Js判断数组中是否存在某个元素「建议收藏」

    Js判断数组中是否存在某个元素「建议收藏」方法一:indexOf(item,start);Item:要查找的值;start:可选的整数参数,缺省则从起始位子开始查找。indexOf();返回元素在数组中的位置,如果没有则返回-1;例子:vararr=[‘aaa’,’bbb’,’ccc’,’ddd’,’eee’];  vara=arr.indexOf(‘ddd’);  console.log(a);  //3  varb=arr.indexOf(‘d’);  console.log(b);  //-1  我通常的用法:if(

    2022年10月19日
    3

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号