DataLoader函数
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None):
其中几个常用的参数
- dataset 数据集,map-style and iterable-style 可以用index取值的对象、
- batch_size 大小
- shuffle 取batch是否随机取, 默认为False
- sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
- batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
- num_workers 参与工作的线程数
- collate_fn 对取出的batch进行处理
- drop_last 对最后不足batchsize的数据的处理方法
下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系
if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset)
可以看出, 当dataset类型是map style时, shuffle其实就是改变sampler的取值
- shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
- shuffle为True时,sampler是RandomSampler, 就是按随机取样
所以当我们sampler有输入时,shuffle的值就没有意义,后面我们再看sampler的定义方法
再看一段初始化代码
if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler
再看看,BatchSampler的生成过程
# 略去类的初始化 def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch
就是按batch_size从sampler中读取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系
- 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
- 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:
batch_size:attr:shuffle, :attr:sampler, and :attr:drop_last. - 意思就是batch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有
再看batch的生成过程
每个batch都是由迭代器产生的
# DataLoader中iter的部分 def __iter__(self): if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) # 再看调用的另一个类 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() data = self._dataset_fetcher.fetch(index) if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data
对上面的代码进行一一解读, 初始化略过
- 先对_next_index()一步步溯源
def _next_index(self): return next(self._sampler_iter) /// self._sampler_iter = iter(self._index_sampler) # 以上又用了一个迭代器生成索引 self._index_sampler = loader._index_sampler /// def _index_sampler(self): if self._auto_collation: return self.batch_sampler else: return self.sampler
- 再看 _dataset_fetcher函数
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): if kind == _DatasetKind.Map: return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) else: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) # 按map-style往下看 class _MapDatasetFetcher(_BaseDatasetFetcher): # 略过初始化 def fetch(self, possibly_batched_index): if self.auto_collation: data = [self.dataset[idx] for idx in possibly_batched_index] # 关键 else: data = self.dataset[possibly_batched_index] return self.collate_fn(data)
def _auto_collation(self): return self.batch_sampler is not None
sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。
- 我们可以看下自带的RandomSampler类中最重要的iter函数
def __iter__(self): n = len(self.data_source) # dataset的长度, 按顺序索引 if self.replacement:# 对应的replace参数 return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) return iter(torch.randperm(n).tolist())
- 比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
- 比如,迭代器和生成器的使用, 以及区别
附上最近读的reid的代码中涉及sampler的部分
class ImageDataset(Dataset): """Image Person ReID Dataset""" def __init__(self, dataset, transform=None): self.dataset = dataset self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, index): img_path, pid, camid = self.dataset[index] img = read_image(img_path) if self.transform is not None: img = self.transform(img) return img, pid, camid, img_path
collate_fn 函数,可以从上面的fetch部分中看到, 也是对读取到的batch进行处理的一个对象,所以,预处理实际上也可以放在collate_fn中。
发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/213104.html原文链接:https://javaforall.net
