TensorFlow CNN 测试CIFAR-10数据集

TensorFlow CNN 测试CIFAR-10数据集




本系列文章由
@yhl_leo
出品,转载请注明出处。


文章链接:
http://blog.csdn.net/yhl_leo/article/details/50738311



1 CIFAR-10 数据

CIFAR-10数据集是机器学习中的一个通用的用于图像识别的基础数据集,官网链接为:The CIFAR-10 dataset

cifar10

下载使用的版本是:

version

将其解压后(代码中包含自动解压代码),内容为:

cifar10 data

cifar10 data2

2 测试代码

测试代码公布在GitHub:yhlleo

主要代码及作用:

文件 作用
cifar10_input.py 读取本地或者在线下载CIFAR-10的二进制文件格式数据集
cifar10.py 建立CIFAR-10的模型
cifar10_train.py 在CPU或GPU上训练CIFAR-10的模型
cifar10_multi_gpu_train.py 在多个GPU上训练CIFAR-10的模型
cifar10_eval.py 评估CIFAR-10模型的预测性能

该部分的代码,介绍了如何使用TensorFlow在CPU和GPU上训练和评估卷积神经网络(convolutional neural network, CNN)。

3 相关网页及教程

更加详细地介绍说明,请浏览网页:Convolutional Neural Networks

中文网站极客学院也有该部分的汉译版:卷积神经网络

代码源自tensorflow官网:tensorflow/models/image/cifar10

4 代码修改说明

GitHub公布代码相对源码(本人的Tensorflow版本还是0.5),主要进行了以下修正:

  • cifar10.py
# indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])

# or
indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])

此处,源码编译时会出现以下错误:

  ...
  File ".../cifar10.py", line 271, in loss
    indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
TypeError: range() takes at least 2 arguments (1 given)
  • cifar10_input_test.py
#self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))

import compat as cp
...

self.assertEqual("%s:%d" % (filename, i), cp.as_text(key))

不然的话,我测试的时候就会出现这的错误:

AttributeError: 'module' object has no attribute 'compat'
  • cifar10_train.pycifar10_multi_gpu_train.py

源代码里的最大迭代次数max_steps1000000,需要训练几个小时,不忍心折腾我的破笔记本,就改为了20000

其他改动,例如导入模块或者文件路径等,都很容易理解,就不列举了~

运行结果,与官网上公布的一致,也不再列举。附上一张运行结果截图:

cifartrain

转载于:https://www.cnblogs.com/hehehaha/p/6332160.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/109139.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • C# .net中获取台式电脑中串口设备的名称

    C# .net中获取台式电脑中串口设备的名称

    2022年2月23日
    38
  • 数据结构:循环队列(C语言实现)[通俗易懂]

    数据结构:循环队列(C语言实现)[通俗易懂]生活中有很多队列的影子,比如打饭排队,买火车票排队问题等,可以说与时间相关的问题,一般都会涉及到队列问题;从生活中,可以抽象出队列的概念,队列就是一个能够实现“先进先出”的存储结构。队列分为链式队列和静态队列;静态队列一般用数组来实现,但此时的队列必须是循环队列,否则会造成巨大的内存浪费;链式队列是用链表来实现队列的。这里讲的是循环队列,首先我们必须明白下面几个问题一、循环队列的基础知识1

    2022年6月2日
    39
  • Linux 解压 zip 分卷

    Linux 解压 zip 分卷对于一个大的文件,使用分卷压缩得到如下文件:传到Linux目录下,希望解压出来,需要使用zip-F命令修复分卷,从而合成正确的一个压缩文件zip-FUCF-101.zip–outucf101.zip得到ucf101.zip,然后解压ucf101.zip即可unzipucf101.zip…

    2022年5月23日
    195
  • Android SDK下载太慢

    Android SDK下载太慢AndroidSDK下载太慢,可以通过设置合适的代理服务器来解决。

    2022年7月19日
    17
  • 图论简介[通俗易懂]

    图论简介[通俗易懂]这里介绍图论(GraphTheory),图论是计算机科学中非常重要的一部分内容,甚至可以单独划分成为一个领域。很多人第一次接触到图论这个词,就觉得图论是研究和图画相关的内容。不过当大家真的去学习图

    2022年8月1日
    9
  • IP地址和域名的关系

    IP地址和域名的关系1、ip地址和域名是一对多的关系,一个ip地址可以有多个域名,但是相反,一个域名只能有一个ip地址;2、ip地址是数字型的,为了方便记忆,才有了域名,通过域名地址就能找到ip地址;3、ip,全称为互联网协议地址,是指ip地址,意思是分配给用户上网使用的网络协议的设备的数字标签;4、常用的ip地址分为IPv4和IPv6两大类;什么是IP地址1、IP地址是IP协议提供的一种统一的地址格式,他为互联网上的每一台主机和每一个网络都分配一个唯一的逻辑地址,以此来屏蔽物理地址的差异;

    2022年4月5日
    87

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号