pytorch之DataLoader

pytorch之DataLoaderpytorch之DataLoader在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助实现这些功能。Dataset只负责数据的抽象,一次调用__getitem__只返回一个样本。DataLoader的函数定义如下:DataLoader(dataset,batch_size=1,shu…

大家好,又见面了,我是你们的朋友全栈君。

pytorch之DataLoader

在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助实现这些功能。Dataset只负责数据的抽象,一次调用__getitem__只返回一个样本。

DataLoader的函数定义如下: DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

dataset:加载的数据集(Dataset对象)
batch_size:batch size
shuffle::是否将数据打乱
sampler: 样本抽样,后续会详细介绍
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

from torch.utils import data
import os
from PIL import  Image
import torch as t
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

#定义transform1
normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform1  = T.Compose([
         T.RandomResizedCrop(224),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
         normalize,
])

在这里插入图片描述
在这里插入图片描述

#实例化数据集dataset
dataset = ImageFolder('data1/dogcat_2/', transform=transform1)
#利用Dataloader函数加载
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

#取一个batch
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
print(imgs.size()) # batch_size, channel, height, weighttorch.Size([3, 3, 224, 224])
print('*****')
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())

在这里插入图片描述

transform = T.Compose([
    T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
    T.CenterCrop(224), # 从图片中间切出224*224的图片
    T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms=transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = 0 if 'dog' in img_path.split('/')[-1] else 1
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

   
class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
    def __getitem__(self, index):
        try:
            # 调用父类的获取函数,即 DogCat.__getitem__(self, index)
            return super(NewDogCat,self).__getitem__(index)
        except:
            return None, None

from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return t.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据

在这里插入图片描述

dataset = NewDogCat('data1/dogcat_wrong/', transforms=transform)
#print(dataset[5])
print('*************')
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn,shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas.size(),batch_labels.size())

在这里插入图片描述
来看一下上述batch_size的大小。其中第1个的batch_size为1,这是因为有一张图片损坏,导致其无法正常返回。而最后1个的batch_size也为1,这是因为共有9张(包括损坏的文件)图片,无法整除2(batch_size),因此最后一个batch的数据会少于batch_szie,可通过指定drop_last=True来丢弃最后一个不足batch_size的batch。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/138631.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • spring自定义注解实现(spring里面的注解)

    java注解:附在代码中的一些元信息,用于在编译、运行时起到说明、配置的功能。一、元注解java提供了4中元注解用于注解其他注解,所有的注解都是基于这四种注解来定义的。@Target注解:用于描述注解的使用范围,超出范围时编译失败。 取值类型(ElementType):  1.CONSTRUCTOR:用于描述构造器  2.FIELD:用于描述域(成

    2022年4月15日
    25
  • java文件上传到指定的路径_java sftp上传文件

    java文件上传到指定的路径_java sftp上传文件在java中获得文件的路径在我们做上传文件操作时是不可避免的。web上运行1:this.getClass().getClassLoader().getResource(“/”).getPath();this.getClass().getClassLoader().getResource(“”).getPath();得到的是ClassPath的绝对URI路径。如:/D:/jboss-4.2….

    2025年9月2日
    12
  • Python标准库 (pickle包,cPickle包)

    Python标准库 (pickle包,cPickle包)在之前对Python对象的介绍中(面向对象的基本概念,面向对象的进一步拓展),我提到过Python“一切皆对象”的哲学,在Python中,无论是变量还是函数,都是一个对象。当Python运行时,对象存储在内存中,随时等待系统的调用。然而,内存里的数据会随着计算机关机和消失,如何将对象保存到文件,并储存在硬盘上呢? 计算机的内存中存储的是二进制的序列(当然,在Linux眼中,是文本流)。我们…

    2022年4月20日
    35
  • onpropertychange替代方案[通俗易懂]

    onpropertychange替代方案1.onpropertychange的介绍onpropertychange事件就是property(属性)change(改变)的时候,触发事件。这是IE专有的!如果想兼容其它浏览器,有个类似的事件,oninput!可能大家会想到另外一个事件:onchange。但是,onchange有两个弊端。一、就是它在触发对象失去焦点时,才触发onchange事件。二、如果得用javascript改变触发对象的属性时,并不能触发onchange事件,oninput也

    2022年4月17日
    59
  • 整数规划matlab实例,整数规划matlab[通俗易懂]

    整数规划matlab实例,整数规划matlab[通俗易懂]整数规划matlabTag内容描述:1、例已知非线性整数规划为maxz=x12+x22+3×32+4×42+2×52-8×1-2×2-3×3-x4-2x5s.t.0xi99,i=1,2,5×1+x2+x3+x4+x5400x1+2×2+2×3+x4+6x58002x1+x2+6x3200x3+x4+5×5200(1)编写M文件mengte.m,定义目标函数f和约束向量函数g,程序如下:funct…

    2022年7月12日
    22
  • 记一次阿里云服务器因Redis被【挖矿病毒crypto和pnscan】攻击的case,附带解决方案

    记一次阿里云服务器因Redis被【挖矿病毒crypto和pnscan】攻击的case,附带解决方案一、事情缘由前段时间给大家录制《玩转Docker》教学视频,链接请参阅https://www.bilibili.com/video/BV1BK4y1A78M,为了更好的演示,就在阿里云搞了一个云服务器,自己本地连接。 云服务器到手之后,为了可以让自己本地可以连接到阿里云服务器上就做了如下操作: 开放了ssh(port:22)端口:为了远程连接使用。 开放了剩余的其他所有端口:录制教学视频时启动了相关容器,容器的一些固定端口6379等映射到了宿主机的随机端口上,为了能从公网访问到容器的内容,就开放

    2022年6月10日
    33

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号