pytorch Dataloader Sampler参数深入理解

pytorch Dataloader Sampler参数深入理解DataLoader 函数参数与初始化 def init self dataset batch size 1 shuffle False sampler None batch sampler None num workers 0 collate fn None pin memory False

DataLoader函数

  • 参数与初始化

 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None): 

其中几个常用的参数

  1. dataset 数据集,map-style and iterable-style 可以用index取值的对象、
  2. batch_size 大小
  3. shuffle 取batch是否随机取, 默认为False
  4. sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
  5. batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
  6. num_workers 参与工作的线程数
  7. collate_fn 对取出的batch进行处理
  8. drop_last 对最后不足batchsize的数据的处理方法

下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系

 if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset) 

可以看出, 当dataset类型是map style时, shuffle其实就是改变sampler的取值

  • shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
  • shuffle为True时,sampler是RandomSampler, 就是按随机取样

所以当我们sampler有输入时,shuffle的值就没有意义,后面我们再看sampler的定义方法

再看一段初始化代码

 if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler 

再看看,BatchSampler的生成过程

# 略去类的初始化 def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch 

就是按batch_size从sampler中读取索引, 并形成生成器返回。

以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系

  • 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
  • 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
  • 意思就是batch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有

再看batch的生成过程

每个batch都是由迭代器产生的

# DataLoader中iter的部分 def __iter__(self): if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) # 再看调用的另一个类 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() data = self._dataset_fetcher.fetch(index) if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data 

对上面的代码进行一一解读, 初始化略过

  • 先对_next_index()一步步溯源
 def _next_index(self): return next(self._sampler_iter) /// self._sampler_iter = iter(self._index_sampler) # 以上又用了一个迭代器生成索引  self._index_sampler = loader._index_sampler /// def _index_sampler(self): if self._auto_collation: return self.batch_sampler else: return self.sampler 
  • 再看 _dataset_fetcher函数
 def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): if kind == _DatasetKind.Map: return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) else: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) # 按map-style往下看 class _MapDatasetFetcher(_BaseDatasetFetcher): # 略过初始化 def fetch(self, possibly_batched_index): if self.auto_collation: data = [self.dataset[idx] for idx in possibly_batched_index] # 关键 else: data = self.dataset[possibly_batched_index] return self.collate_fn(data) 
 def _auto_collation(self): return self.batch_sampler is not None 
  • sampler 参数的使用

sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。

  • 我们可以看下自带的RandomSampler类中最重要的iter函数
 def __iter__(self): n = len(self.data_source) # dataset的长度, 按顺序索引 if self.replacement:# 对应的replace参数 return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) return iter(torch.randperm(n).tolist()) 
  • 比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
  • 比如,迭代器和生成器的使用, 以及区别

附上最近读的reid的代码中涉及sampler的部分

  • 关于dataset预处理和collate_fn的一些问题

class ImageDataset(Dataset): """Image Person ReID Dataset""" def __init__(self, dataset, transform=None): self.dataset = dataset self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, index): img_path, pid, camid = self.dataset[index] img = read_image(img_path) if self.transform is not None: img = self.transform(img) return img, pid, camid, img_path 

collate_fn 函数,可以从上面的fetch部分中看到, 也是对读取到的batch进行处理的一个对象,所以,预处理实际上也可以放在collate_fn中。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/213104.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午6:33
下一篇 2026年3月18日 下午6:33


相关推荐

  • C#验证码的实现_验证码怎么实现

    C#验证码的实现_验证码怎么实现一.编程思想(1).验证码由四位随机数字或者字母组成,此时就要考虑怎么获取随机数(2).各个字符之间怎么进行连接(3).当点击更换时会重新生成四位随机数(4).四位字符的显示二.代码的实现(1).引入伪随机数生成器Random,生成随机数实例化Random:Randomp=newRandom();//表示伪随机数生成器Randomp=newRandom();……

    2025年10月16日
    3
  • Linux移植一_linux从零开始移植

    Linux移植一_linux从零开始移植Linux移植一本文博客链接:http://blog.csdn.net/jdh99,作者:jdh,转载请注明.现在手上有两个开发板,一个是tiny6410,一个是OK6410-A.tiny6410上跑的是linux2.6.38,支持alsa,uboot支持yaffs2系统以及从sd卡启动linux,并且移植了qte的库.而ok6410上跑的是linux2.6.36,没有

    2025年11月27日
    5
  • 已安装的实例怎么删除_如何删除数据库实例

    已安装的实例怎么删除_如何删除数据库实例1.事件问题描述考虑到整个项目组的需求,我将PLC博图V16卸载,然后重新下载安装博图V15.1,然而因为在删除前博图V16时没有删除干净,安装博图V15.1导致出现以下问题:请删除SQLServer的”WinCC”实例,因为在卸载”WinCCProfessional”或”WinCCRuntimeProfessional”之后,该实例仍然存在于TIAPortal.2.解决方案…

    2022年10月2日
    5
  • html 的scor属性,scrollheight属性「建议收藏」

    html 的scor属性,scrollheight属性「建议收藏」scrollHeight属性是属于什么范畴?CSS布局HTML小编今天和大家分享问大神,Height属性到底指的是什么html设置overflow-x:scroll;属性后怎么让指定位如果页面不够长(至少窗口长度两倍),那肯定滚动不到一半的位置。否则任何浏览器都不会产生误差。下面的例子输出100个,页面加载的时候会滚动到第51个。window.onload=function(…

    2022年7月23日
    17
  • 作为测试负责人如何规范测试团队建设_测试人员如何开展测试工作

    作为测试负责人如何规范测试团队建设_测试人员如何开展测试工作前言:今天是2021年11月17日,我入职新公司工作的第20天,工作也确实比较忙,准确的来说在公司大家都忙,我基本上都是早上7点半起床,晚上12点到家,睡午觉的时间忙中偷闲更新下博客!作为测试负责人如何规范测试团队?一、我的提问二、你会发现存在的问题1、流程不规范2、缺乏沟通3、没有共享文档4、没有输出三、如何做好流程规范1、测试进度及计划面板2、技术评审3、提测规范4、测试用例评审四、如何做好流程规范1、测试进度及计划面板一、我的提问当你来到一个项目不规范的技术团队,你会怎么处理呢?二、你会发现存

    2025年8月5日
    5
  • Apache的URL地址重写(RewriteCond与RewriteRule)

    Apache的URL地址重写(RewriteCond与RewriteRule)Apache的URL地址重写http://hi.baidu.com/sonan/blog/item/c408963d89468208bba16716.html第一种方法:Apache环境中如果要将URL地址重写,正则表达式是最基本的要求,但对于一般的URL地址来说,基本的匹配就能实现我们大部分要求,因此除非是非常特殊的URL地址,但这不是我要讨论的范围,简单几招学会Apache中URL地

    2022年6月11日
    26

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号