手把手用keras分类mnist数据集

实战流程获得数据,并将数据处理成合适的格式按照自己的设计搭建神经网络设定合适的参数训练神经网络在测试集上评价训练效果一、认识mnist数据集fromkeras.utilsimportto_categoricalfromkerasimportmodels,layers,regularizersfromkeras.optimizersimportRMSprop…

大家好,又见面了,我是你们的朋友全栈君。

博主将整个实战做成了视频讲解,视频链接:https://www.bilibili.com/video/BV16g4y1z7Qu/

实战流程

  • 获得数据,并将数据处理成合适的格式
  • 按照自己的设计搭建神经网络
  • 设定合适的参数训练神经网络
  • 在测试集上评价训练效果

一、认识mnist数据集

from keras.utils import to_categorical
from keras import models, layers, regularizers
from keras.optimizers import RMSprop
from keras.datasets import mnist
import matplotlib.pyplot as plt

# 加载数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape, test_images.shape)
print(train_images[0])
print(train_labels[0])
plt.imshow(train_images[0])
plt.show()

将图片由二维铺开成一维

train_images = train_images.reshape((60000, 28*28)).astype('float')
test_images = test_images.reshape((10000, 28*28)).astype('float')
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

二、搭建一个神经网络

在这里插入图片描述

network = models.Sequential()
network.add(layers.Dense(units=15, activation='relu', input_shape=(28*28, ),))
network.add(layers.Dense(units=10, activation='softmax'))

三、神经网络训练

1、编译:确定优化器和损失函数等

# 编译步骤
network.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

2、训练网络:确定训练的数据、训练的轮数和每次训练的样本数等

# 训练网络,用fit函数, epochs表示训练多少个回合, batch_size表示每次训练给多大的数据
network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)

四、用训练好的模型进行预测,并在测试集上做出评价

# 来在测试集上测试一下模型的性能吧
y_pre = network.predict(test_images[:5])
print(y_pre, test_labels[:5])
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, " test_accuracy:", test_accuracy)

五、完整代码

1、使用全连接神经网络

from keras.utils import to_categorical
from keras import models, layers, regularizers
from keras.optimizers import RMSprop
from keras.datasets import mnist
import matplotlib.pyplot as plt

# 加载数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images.reshape((60000, 28*28)).astype('float')
test_images = test_images.reshape((10000, 28*28)).astype('float')
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

network = models.Sequential()
network.add(layers.Dense(units=128, activation='relu', input_shape=(28*28, ),
                         kernel_regularizer=regularizers.l1(0.0001)))
network.add(layers.Dropout(0.01))
network.add(layers.Dense(units=32, activation='relu',
                         kernel_regularizer=regularizers.l1(0.0001)))
network.add(layers.Dropout(0.01))
network.add(layers.Dense(units=10, activation='softmax'))

# 编译步骤
network.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# 训练网络,用fit函数, epochs表示训练多少个回合, batch_size表示每次训练给多大的数据
network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)

# 来在测试集上测试一下模型的性能吧
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, " test_accuracy:", test_accuracy)

2、使用卷积神经网络

from keras.utils import to_categorical
from keras import models, layers
from keras.optimizers import RMSprop
from keras.datasets import mnist
# 加载数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 搭建LeNet网络
def LeNet():
    network = models.Sequential()
    network.add(layers.Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    network.add(layers.AveragePooling2D((2, 2)))
    network.add(layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
    network.add(layers.AveragePooling2D((2, 2)))
    network.add(layers.Conv2D(filters=120, kernel_size=(3, 3), activation='relu'))
    network.add(layers.Flatten())
    network.add(layers.Dense(84, activation='relu'))
    network.add(layers.Dense(10, activation='softmax'))
    return network
network = LeNet()
network.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

train_images = train_images.reshape((60000, 28, 28, 1)).astype('float') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# 训练网络,用fit函数, epochs表示训练多少个回合, batch_size表示每次训练给多大的数据
network.fit(train_images, train_labels, epochs=10, batch_size=128, verbose=2)
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, " test_accuracy:", test_accuracy)

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/126364.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • C语言输出有颜色的字体

    C语言输出有颜色的字体先看下面的一段代码:#include<stdio.h>intmain(intargc,char**argv){printf(“\033[44;37;5mhelloworld\033[0m\n”);return0;}编译后运行上述代码,结果如下:可见,此时输出的字体和背景已经有了颜色。由上可知,在输出时候加上“\033[…

    2022年7月24日
    24
  • asp.net(C#)中Repeater嵌套绑定Repeater[通俗易懂]

    asp.net(C#)中Repeater嵌套绑定Repeater[通俗易懂]Repeater嵌套Repeater的结构:一般写过的都能看懂吧privatevoidRpTypeBind(){//GetQuestionTypeAndCount()返回一个datatablethis.rptypelist.DataSource=LiftQuestionCtr.GetQuestionTypeAndCount(

    2022年7月14日
    16
  • 素数判断算法(高效率)「建议收藏」

    素数判断算法(高效率)「建议收藏」chuanbindeng的素数判断算法关于素数的算法是信息学竞赛和程序设计竞赛中常考的数论知识,在这里我跟大家讲一下寻找一定范围内素数的几个算法。看了以后相信对大家一定有帮助。   正如大家都知道的那样,一个数n如果是合数,那么它的所有的因子不超过sqrt(n)–n的开方,那么我们可以用这个性质用最直观的方法来求出小于等于n的所有的素数。   num=0;

    2022年6月18日
    27
  • webstorm 2021 激活码_最新在线免费激活

    (webstorm 2021 激活码)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

    2022年3月28日
    110
  • NumPy使用图解教程「建议收藏」

    NumPy使用图解教程「建议收藏」NumPy使用图解教程

    2022年8月2日
    8
  • HTTP默认端口_http协议使用的端口号

    HTTP默认端口_http协议使用的端口号HTTP默认端口80是http协议的默认端口,是在输入网站的时候其实浏览器(非IE)已经帮你输入协议了,所以你输入http://baidu.com,其实是访问http://baidu.com:80。而8080,一般用与webcahe,完全不一样的两个,比如linux服务器里apache默认跑80端口,而apache-tomcat默认跑8080端口,其实端口没有实际意义只是一个接口,主要是看服务的监听端口。443是https的默认端口端口号标识了一个主机上进行通信的不同的应用程序。 H.

    2022年9月16日
    0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号