十分钟搞懂Pytorch如何读取MNIST数据集

前言本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…正文在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的:train_dataset=datasets.MNIST(root=’./MNIST’,train=True,transform=data_tf,download=True)解释一下参数datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数

大家好,又见面了,我是你们的朋友全栈君。

前言

本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…

正文

在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的:

train_dataset = datasets.MNIST(root='./MNIST',train=True,transform=data_tf,download=True)

解释一下参数

datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。

train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集)

transform则是读入我们自己定义的数据预处理操作

download=True则是当我们的根目录(root)下没有数据集时,便自动下载。

如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:

在这里插入图片描述

其中我们所需要的文件主要在raw文件夹下

train-images-idx3-ubyte.gz: training set images (9912422 bytes) 
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) 

t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) 
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) 

接下来,书中是如此加载数据集的

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=5,
                                           shuffle=True)

由于DataLoader为Pytorch内部封装好的函数,所以对于它的调用方法需要自行去查阅。

我在最开始疑惑的点:传入的根目录在下载好数据集后,为MNIST下两个文件夹,而processed和raw文件夹下还有诸多文件,所以到底是如何读入数据的呢?所以我决定将数据集下载后,通过读取本地的MINIST数据集并进行装载。

首先,自定义数据类来继承和重写Dataset抽象类

class DealDataset(Dataset):
    """ 读取数据、初始化数据 """
    def __init__(self, folder, data_name, label_name,transform=None):
        (train_set, train_labels) = self.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):

        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)
    ''' load_data也是我们自定义的函数,用途:读取数据集中的数据 ( 图片数据+标签label '''
    def load_data(self,data_folder, data_name, label_name):
    with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    return (x_train, y_train)

接下来,调用我们自定义的数据类来加载数据集

trainDataset = DealDataset('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(
    dataset=trainDataset,
    batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片
    shuffle=False,
)

通过这种方式便可以大概了解了读取数据集的过程。

接下来,我们来验证以下我们数据是否正确加载

# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:其实这里是用cv2.imshow来展示图片,但是我的代码是在jupyter notebook上写的,所以只能通过plt来代替加载。
在这里插入图片描述

数据加载成功~

深入探索

可以看到,在load_data函数中

y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

这个offset=8又是为啥呢?
我们进入MNIST数据集的官方页面进行查看
在这里插入图片描述

通过文档介绍,可以看到
offset的0000-0003是 magic number,所以跳过不读,
offset的0004-0007是items数目
接下来这些代表的就是标签

同理对于

x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train)

在这里插入图片描述

根据刚才的分析方法,也可以明白为什么offset=16了

完整代码

1.直接使用pytorch自带的mnist数据集加载

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt

data_tf = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5])]
)

train_dataset = datasets.MNIST(root='./coding/learning/lrdata/MNIST',train=True,transform=data_tf,download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=5,
                                           shuffle=True)
# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:记得自己修改root根目录。

2.使用自定义的数据类加载本地MNIST数据集

import numpy as np
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import gzip
import os
import torchvision
import cv2
import matplotlib.pyplot as plt

class DealDataset(Dataset):
    """ 读取数据、初始化数据 """
    def __init__(self, folder, data_name, label_name,transform=None):
        (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):

        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)
  
 def load_data(data_folder, data_name, label_name):
    with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    return (x_train, y_train)

trainDataset = DealDataset('./coding/learning/lrdata/MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(
    dataset=trainDataset,
    batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片
    shuffle=False,
)

# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

参考

1.《深度学习入门之Pytorch》- 廖星宇
2.使用Pytorch进行读取本地的MINIST数据集并进行装载
3.顺藤摸瓜-mnist数据集的补充

在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/126858.html原文链接:https://javaforall.net

(0)
上一篇 2022年4月8日 下午4:20
下一篇 2022年4月8日 下午4:20


相关推荐

  • 搜索引擎使用技巧

    1、双引号把搜索词放在双引号中,代表完全匹配搜索,也就是说搜索结果返回的页面包含双引号中出现的所有的词,连顺序也必须完全匹配。百度和Google都支持这个指令。例如搜索:“Python”。2、减号减号代表搜索不包含减号后面的词的页面。使用这个指令时减号前面必须是空格,减号后面没有空格,紧跟着需要排除的词。Google和bd都支持这个指令。例如:搜索-引擎返回的则是包含“搜索”这…

    2022年4月5日
    62
  • IntelliJ IDEA创建maven web项目(IDEA新手适用)

    IntelliJ IDEA创建maven web项目(IDEA新手适用)PS:从eclipse刚转到IDEA,对于这个陌生的工具我表示无言,但听说很好用,也就试试,结果我几乎花了一晚上的时间才搭起来mavenweb项目,觉得在此给各位一个搭建mavenweb项目的教程,指出我踩过的各种坑!步骤一:首先先创建一个project,在这里就是创建一个maven的工作空间步骤二:按照下面的步骤操作就可以了,最后next首先,选择左边的maven然后在右…

    2022年6月26日
    58
  • APK 签名:v1 v2 v3 v4

    APK 签名:v1 v2 v3 v4通过对Apk进行签名,开发者可以证明对Apk的所有权和控制权,可用于安装和更新其应用。而在Android设备上的安装Apk,如果是一个没有被签名的Apk,则会被拒绝安装。在安装Apk的时候,软件包管理器也会验证Apk是否已经被正确签名,并且通过签名证书和数据摘要验证是否合法没有被篡改。只有确认安全无篡改的情况下,才允许安装在设备上。简单来说,APK的签名主要作用有两个:证明APK的所有者。 允许Android市场和设备校验APK的正确性。

    2022年5月17日
    174
  • Hash与Hash冲突及四种解决方案

    Hash与Hash冲突及四种解决方案Hash 与 Hash 冲突大家都了解过 HashMap 或者其他有着 hash 表结构的容器 所以首先我们来谈谈什么是 Hash 什么是 Hash 冲突什么是 Hash Hash 表也称散列表 也有直接译作哈希表 Hash 表是一种特殊的数据结构 它同数组 链表以及二叉排序树等相比较有很明显的区别 它能够快速定位到想要查找的记录 而不是与表中存在的记录的关键字进行比较来进行查找 原理 Hash 表采用一个映射函数 f key gt address 将关键字映射到该记录在表中的存储位置 从而在想要查找该记录时 可以

    2026年3月26日
    3
  • 基于MATLAB的卷积神经网络车牌识别系统

    基于MATLAB的卷积神经网络车牌识别系统车牌识别是基于车牌照片的车牌信息的识别工作,车牌识别技术对我们的实际生活至关重要,例如交通违规行为的增加,拦截非法车辆,在速度上能够进行快速识别能够很好地解决这些问题。获得的照片的质量是影响车牌识别准确性的最重要因素之一。卷积神经网络在图像识别领域具有良好的适应性,目前在计算机视觉任务中应用广泛,并在手写数字识别、人脸识别、车牌识别等图像领域的应用中取得了很好的效果。本文基于MATLAB卷积神…

    2022年5月29日
    30
  • 面试常问到的经典100问 附答案和点评 参加过面试的人就知道这些题目出现的频率有多高啦 ①

    面试常问到的经典100问 附答案和点评 参加过面试的人就知道这些题目出现的频率有多高啦 ①1 问题 请给我们谈谈你自己的一些情况回答 简要的描述你的相关工作经历以及你的一些特征 包括与人相处的能力和个人的性格特征 如果你一下子不能够确定面试者到底需要什么样的内容 你可以这样说 有没有什么您特别感兴趣的范围 BYSCN com 点评 企业以此来判断是否应该聘用你 通过你的谈论 可以看出你想的是如何为公司效力还是那些会影响工作的个人问题 当然 还可以知道你的一些背景

    2026年3月27日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号