pytorch下载CIFAR10数据集[通俗易懂]

pytorch下载CIFAR10数据集[通俗易懂]importtorchfromtorchvisionimportdatasetsfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderdefmain():batchsz=32cifar_train=datasets.CIFAR10(‘cifar’,True,transform=transforms.Compose([transforms.Re

大家好,又见面了,我是你们的朋友全栈君。

import torch 
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader


def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor
    ]), download=True)
    cifar_train = DataLoader(cifar_train,batch_size=batchse,shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor
    ]), download=True)
    cifar_teat = DataLoader(cifar_train,batch_size=batchse,shuffle=True)

    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)


if __name__ == "__main__":
    main()
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152097.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号