PyTorch 中的数据类型 torch.utils.data.DataLoader

PyTorch 中的数据类型 torch.utils.data.DataLoaderDataLoader是PyTorch中的一种数据类型。在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?下面就研究一下:先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本) __init__(构造函数)中的几个重要的属性:1、dataset:(数据类型dataset)输入的数据类型。看名字感觉就像是数据库,…

大家好,又见面了,我是你们的朋友全栈君。

DataLoader是PyTorch中的一种数据类型。

在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?

下面就研究一下:

先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本)

 __init__(构造函数)中的几个重要的属性:

1、dataset:(数据类型 dataset)

输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。

2、batch_size:(数据类型 int)

每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。

6、sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、num_workers:(数据类型 Int)

工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。

8、pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

9、drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

10、timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

11、worker_init_fn(数据类型 callable,没见过的类型)

子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

 

从DataLoader类的属性定义中可以看出,这个类的作用就是实现数据以什么方式输入到什么网络中。

代码一般是这么写的:

# 定义学习集 DataLoader

train_data = torch.utils.data.DataLoader(各种设置...) 

# 将数据喂入神经网络进行训练

for i, (input, target) in enumerate(train_data): 
    循环代码行……

 

如果全部采用默认设置输入数据,数据就是一行一行按顺序输入到神经网络。如果对数据的输入有特殊要求。

比如:想打乱一下数据的排序,可以设置 shuffle(洗牌)为True;

比如:想数据是一捆的输入,可以设置 batch_size 的数目;

比如:想随机抽取的模式输入,可以设置 sampler 或 batch_sampler。如何定义抽样规则,可以看sampler.py脚本。这里不是重点;

比如:像多线程输入,可以设置 num_workers 的数目

其他的就不太懂了,以后实际应用时碰到特殊要求再研究吧。

DataLoader类中还有3个函数:

def __setattr__(self, attr, val):
        if self.__initialized and attr in (‘batch_size’, ‘sampler’, ‘drop_last’):
            raise ValueError(‘{} attribute should not be set after {} is ‘
                             ‘initialized’.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

def __iter__(self):
        return _DataLoaderIter(self)

def __len__(self):
        return len(self.batch_sampler)

关键是第二个函数,

_DataLoaderIter 又是一个类,被一起写在DataLoader.py文件中。

主要是用来处理各种设置如何运作的,这里就不管那么多啦。

最后,如果要导入自己各种古灵精怪的数据,就要看看 DataSet 又是如何操作的。

torch.utils.data主要包括以下三个类: 
1. class torch.utils.data.Dataset

其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder. 
2. class torch.utils.data.sampler.Sampler(data_source) 
参数: data_source (Dataset) – dataset to sample from 
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度. 
3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None) 
参数:

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/143863.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月20日 下午5:40
下一篇 2022年5月20日 下午5:40


相关推荐

  • jdbc如何连接mysql数据库_sqlplus连接远程数据库

    jdbc如何连接mysql数据库_sqlplus连接远程数据库好多朋友遇到了在本地可以连接mysql数据库,而在jsp页面连接远程mysql数据库而连不上的问题,现总结以下:1.配置远程mysql数据库,使其允许远程tcp/ip连接,开放默认端口(3306) 或者设置为3309,2.创建用户,使其具有在任意HOST连接任意database的权限;3.在jdbc连接串中设置端口,如:jdbc:mysql://192.168.0.2:3309/ic

    2022年10月10日
    4
  • 门面模式和适配器模式_数字化门店转型

    门面模式和适配器模式_数字化门店转型门面模式Facade动机模式定义结构要点总结笔记动机上述A方案的问题在于组件的客户和组件中各种复杂的子系统有了过多的耦合,随着外部客户程序和各子系统的演化.这种过多的耦合面临很多变化的挑战如何简化外部客户端和系统间的交互接口呢?如何将外部客户程序的演化和内部子系统的变化之间的依赖相互解耦模式定义为子系统中的一组接口提供一个**一致(稳定)**的界面,Facade模式定义了一个高层接口,这个接口使得这一子系统更加容易使用(复用)结构要点总结从客户程序的角度来看,Facade模式简化了整个

    2022年8月9日
    6
  • 免费版(个人家庭免费使用)xshell7 和 xftp7 下载

    免费版(个人家庭免费使用)xshell7 和 xftp7 下载xshell6、xftp6个人免费版:百度云下载地址:https://pan.baidu.com/s/19mTPpYgXo65u9SCI1IINPQ密码:9wr0安装完毕,启动时会有弹出框,关闭即可缺点:一个xshell中shell窗口个数最多四个,有限制,可以下载下面xmanager5套件,使用不受限制xmanager5[包含xshell,xftp5]:xman…

    2022年10月12日
    3
  • mysql 联合查询_MySQL联合查询

    mysql 联合查询_MySQL联合查询MySQL联合查询联合查询:union,将多次查询(多条select语句)的结果,在字段数相同的情况下,在记录的层次上进行拼接。基本语法联合查询由多条select语句构成,每条select语句获取的字段数相同,但与字段类型无关。基本语法:select语句1+union+[union选项]+select语句2+…;union选项:与select选项一样有两种all:无论重复…

    2022年6月10日
    40
  • 我要自学编程,Java和C语言相比哪个好?[通俗易懂]

    我要自学编程,Java和C语言相比哪个好?[通俗易懂]JavaJava是一种可以撰写跨平台应用软件的面向对象的程序设计语言。Java技术具有卓越的通用性、高效性、平台移植性和安全性,广泛应用于PC、数据中心、游戏控制台、科学超级计算机、移动电话和互联网,同时拥有全球最大的开发者专业社群。C语言学习C语言是一种计算机程序设计语言,属高级语言范畴。它既具有高级语言的特点,又具有汇编语言的特点。它可以作为工作系统设计语言,编写系统应用程序,也可以作为应用程序设计语言,编写不依赖计算机硬件的应用程序,代码清晰精简,十分灵活。语言没有好坏之分,无论学习哪个语言

    2022年7月7日
    37
  • 涨见识| 字节PHP/Golang社招面经[通俗易懂]

    涨见识| 字节PHP/Golang社招面经

    2022年2月14日
    50

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号