Tensorflow 加载本地CIFAR10数据集

Tensorflow 加载本地CIFAR10数据集本文介绍怎样把保存在本地的CIFAR10数据集加载到程序中。数据集网址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz代码:from__future__importabsolute_importfrom__future__importdivisionfrom__future__importprint…

大家好,又见面了,我是你们的朋友全栈君。

本文介绍怎样把保存在本地的CIFAR10数据集加载到程序中。

数据集网址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras import backend as K
import numpy as np
import os

import sys
from six.moves import cPickle

def load_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.
    # Arguments
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.
    # Returns
        A tuple `(data, labels)`.
    """
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            d = cPickle.load(f, encoding='bytes')
            # decode utf8
            d_decoded = {}
            for k, v in d.items():
                d_decoded[k.decode('utf8')] = v
            d = d_decoded
    data = d['data']
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels


def load_data(ROOT):
    """Loads CIFAR10 dataset.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    #dirname = 'cifar-10-batches-py'
    #origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    #path = get_file(dirname, origin=origin, untar=True)
    path = ROOT

    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)

调用时:先将上面代码保存为load_local_cifar10.py

from load_local_cifar10 import load_data


cifar10_dir = './datasets/cifar-10-batches-py'
(x_train, y_train), (x_test, y_test) = load_data(cifar10_dir)

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152071.html原文链接:https://javaforall.net

(0)
上一篇 2022年6月22日 下午12:00
下一篇 2022年6月22日 下午12:00


相关推荐

  • 【分布式事务】GitHub上分布式事务框架压测性能对比

    【分布式事务】GitHub上分布式事务框架压测性能对比一、前言&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;随着项目逐步以微服务开发为趋势,逐渐呈现一个服务对应一个数据库。从中产生了分布式事务的问题:一个操作先后调用不同的服务,要保证服务间的事务一致性,这就是分布式事务解决的问题。&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&amp;nbsp;&am

    2025年8月20日
    4
  • mysql建表与oracle_mysql和oracle建表语句以及数据类型的区别

    mysql建表与oracle_mysql和oracle建表语句以及数据类型的区别1 mysql 和 oracle 建表语句的区别 mysqlDROPTAB order CREATETABLE order id int 11 NOTNULLAUTO INCREMENT number varchar 255 NOTNULLCOMME 工单编号 applicant varchar 255 NOTNULLCOMME

    2026年3月16日
    1
  • PostgreSQL 14和SCRAM认证的改变–应该迁移到SCRAM?

    PostgreSQL 14和SCRAM认证的改变–应该迁移到SCRAM?PostgreSQL14 和 SCRAM 认证的改变应该迁移到 SCRAM 最近 一些 PG 使用者反馈他们切换到 PG14 后 遇到了一些连接错误 FATAL passwordauth

    2026年3月19日
    2
  • java运算符优先级

    java运算符优先级转载博客 https blog csdn net pc gad article details java 中运算符的优先级优先级记忆方法 单目乘除为关系 逻辑三目后赋值 前辈总结的 所谓优先级 就是在表达式中的运算顺序 Java 中常用的运算符的优先级如下表所示 级别为 1 的优先级最高 级别 11 的优先级最低 譬如 x 7 3 2 得到的结

    2026年3月20日
    2
  • lcd开机流程图_LCD1602程序代码及显示流程图.doc[通俗易懂]

    lcd开机流程图_LCD1602程序代码及显示流程图.doc[通俗易懂]LCD1602程序代码及显示流程图LCD1602程序代码及显示流程图lcd1602显示程序代码前些天弄了最小系统板后就想着学习1602的显示程序,可惜坛子里的或网上的,都没有简单的1602显示程序,无柰在网上下载了一段经过反复修改测试,终于有了下面一段代码://———————————-…

    2022年7月16日
    24
  • 深度学习超分辨率重建(总结)[通俗易懂]

    本文为概述,详情翻看前面文章。1.SRCNN:—2,3改进开山之作,三个卷积层,输入图像是低分辨率图像经过双三次(bicubic)插值和高分辨率一个尺寸后输入CNN。图像块的提取和特征表示,特征非线性映射和最终的重建。使用均方误差(MSE)作为损失函数。2.FSRCNN特征提取:低分辨率图像,选取的核9×9设置为5×5。收缩:1×1的卷积核进行降维。非线性映射:用两个串联的3×3的卷积核可以…

    2022年4月1日
    43

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号