Tensorflow 加载本地CIFAR10数据集

Tensorflow 加载本地CIFAR10数据集本文介绍怎样把保存在本地的CIFAR10数据集加载到程序中。数据集网址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz代码:from__future__importabsolute_importfrom__future__importdivisionfrom__future__importprint…

大家好,又见面了,我是你们的朋友全栈君。

本文介绍怎样把保存在本地的CIFAR10数据集加载到程序中。

数据集网址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras import backend as K
import numpy as np
import os

import sys
from six.moves import cPickle

def load_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.
    # Arguments
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.
    # Returns
        A tuple `(data, labels)`.
    """
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            d = cPickle.load(f, encoding='bytes')
            # decode utf8
            d_decoded = {}
            for k, v in d.items():
                d_decoded[k.decode('utf8')] = v
            d = d_decoded
    data = d['data']
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels


def load_data(ROOT):
    """Loads CIFAR10 dataset.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    #dirname = 'cifar-10-batches-py'
    #origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    #path = get_file(dirname, origin=origin, untar=True)
    path = ROOT

    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)

调用时:先将上面代码保存为load_local_cifar10.py

from load_local_cifar10 import load_data


cifar10_dir = './datasets/cifar-10-batches-py'
(x_train, y_train), (x_test, y_test) = load_data(cifar10_dir)

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152071.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • Windows 下搭建LDAP服务器

    Windows 下搭建LDAP服务器TheLightweightDirectoryAccessProtocol,orLDAP,isanapplicationprotocolforqueryingandmodifyingdirectoryservicesrunningoverTCP/IP.(viawikipedia)。LDAP全称是一个轻量级的目录访问协议,它是建立在TCP/IP

    2022年5月14日
    32
  • pytorch lstm时间序列预测问题踩坑「建议收藏」

    这里写目录标题1.做时间序列问题2.问题1.数据集自己做,为多个输入对应多个或一个输出2.损失函数注意:不能用交叉熵nn.CrossEntropyLoss()3.准确率1.做时间序列问题2.问题1.数据集自己做,为多个输入对应多个或一个输出2.损失函数注意:不能用交叉熵nn.CrossEntropyLoss()nn.CrossEntropyLoss()要求target目标值即真实值是标签,是torch.int64类型数据,即整数,不允许小数,如果输入小数会强行取整,应该用nn.MSELo

    2022年4月16日
    42
  • 电子元器件采购知识_电子元件购买

    电子元器件采购知识_电子元件购买​本部分内容为“电子元件知识汇总1-封装、电子元件知识汇总2-封装”的扩展,主要侧重于电子元件的品牌以及采购,若需采购厂商参见“电子元件知识汇总3-厂商”,仅供参考。​

    2022年8月24日
    3
  • Html动态点击按钮实现“+”和“-”功能

    Html动态点击按钮实现“+”和“-”功能  Html动态点击按钮实现“+”和“-”功能&lt;!DOCTYPE html&gt;&lt;html lang="en"&gt; &lt;head&gt; &lt;meta http-equiv="Content-Type" content="text/html;"&gt; &lt;title&gt;html动态实现加减&lt;

    2022年6月13日
    61
  • VS2012 产品密钥「建议收藏」

    VS2012 产品密钥「建议收藏」vs2012产品激活码,序列号,旗舰版(utimate)YKCW6-BPFPF-BT8C9-7DCTH-QXGWC

    2022年10月15日
    0
  • idea2021激活吗[最新免费获取]

    (idea2021激活吗)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

    2022年3月30日
    43

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号