# 大家好,又见面了,我是你们的朋友全栈君。
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
def _make_cifar10_loader(root, train, batch_size, shuffle):
    """Build a DataLoader over one CIFAR-10 split (downloading it if needed).

    Args:
        root: Directory in which CIFAR-10 is stored / downloaded.
        train: True for the training split, False for the test split.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = datasets.CIFAR10(
        root,
        train=train,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            # ToTensor is a class and must be instantiated; passing the class
            # itself makes Compose call ToTensor(img), which fails because the
            # constructor does not accept an image.
            transforms.ToTensor(),
        ]),
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def main(batchsz=32, root='cifar'):
    """Build CIFAR-10 train/test loaders and print the shape of one training batch.

    Args:
        batchsz: Number of samples per batch (default 32).
        root: Directory where CIFAR-10 is stored / downloaded (default 'cifar').
    """
    cifar_train = _make_cifar10_loader(root, True, batchsz, shuffle=True)
    # The test split is only used for evaluation, where sample order is
    # irrelevant, so it is not shuffled.
    cifar_test = _make_cifar10_loader(root, False, batchsz, shuffle=False)

    # Iterators have no .next() method in Python 3; use the builtin next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
# 发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/152097.html
# 原文链接:https://javaforall.net