PyTorch 中的数据类型 torch.utils.data.DataLoader

PyTorch 中的数据类型 torch.utils.data.DataLoaderDataLoader是PyTorch中的一种数据类型。在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?下面就研究一下:先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本) __init__(构造函数)中的几个重要的属性:1、dataset:(数据类型dataset)输入的数据类型。看名字感觉就像是数据库,…

大家好,又见面了,我是你们的朋友全栈君。

DataLoader是PyTorch中的一种数据类型。

在PyTorch中训练模型经常要使用它,那么该数据结构长什么样子,如何生成这样的数据类型?

下面就研究一下:

先看看 dataloader.py脚本是怎么写的(VS中按F12跳转到该脚本)

 __init__(构造函数)中的几个重要的属性:

1、dataset:(数据类型 dataset)

输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。

2、batch_size:(数据类型 int)

每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。

6、sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、num_workers:(数据类型 Int)

工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。

8、pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

9、drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

10、timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

11、worker_init_fn(数据类型 callable,没见过的类型)

子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

 

从DataLoader类的属性定义中可以看出,这个类的作用就是实现数据以什么方式输入到什么网络中。

代码一般是这么写的:

# 定义学习集 DataLoader

train_data = torch.utils.data.DataLoader(各种设置...) 

# 将数据喂入神经网络进行训练

for i, (input, target) in enumerate(train_data): 
    循环代码行……

 

如果全部采用默认设置输入数据,数据就是一行一行按顺序输入到神经网络。如果对数据的输入有特殊要求。

比如:想打乱一下数据的排序,可以设置 shuffle(洗牌)为True;

比如:想数据是一捆的输入,可以设置 batch_size 的数目;

比如:想随机抽取的模式输入,可以设置 sampler 或 batch_sampler。如何定义抽样规则,可以看sampler.py脚本。这里不是重点;

比如:像多线程输入,可以设置 num_workers 的数目

其他的就不太懂了,以后实际应用时碰到特殊要求再研究吧。

DataLoader类中还有3个函数:

def __setattr__(self, attr, val):
        if self.__initialized and attr in (‘batch_size’, ‘sampler’, ‘drop_last’):
            raise ValueError(‘{} attribute should not be set after {} is ‘
                             ‘initialized’.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

def __iter__(self):
        return _DataLoaderIter(self)

def __len__(self):
        return len(self.batch_sampler)

关键是第二个函数,

_DataLoaderIter 又是一个类,被一起写在DataLoader.py文件中。

主要是用来处理各种设置如何运作的,这里就不管那么多啦。

最后,如果要导入自己各种古灵精怪的数据,就要看看 DataSet 又是如何操作的。

torch.utils.data主要包括以下三个类: 
1. class torch.utils.data.Dataset

其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder. 
2. class torch.utils.data.sampler.Sampler(data_source) 
参数: data_source (Dataset) – dataset to sample from 
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度. 
3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None) 
参数:

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/143863.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 离散系统的变换域

    离散系统的变换域

    2022年1月4日
    71
  • Oracle 正则表达式以及常用正则函数

    Oracle 正则表达式以及常用正则函数Oracle正则表达式以及常用函数正则表达式简介正则表达式基础Oracle常用函数正则表达式简介菜鸟教程练习网站1练习网站2练习网站3练习网站4软件下载什么是正则表达式?正则表达式,又称规则表达式。(英语:RegularExpression,在代码中简写为regex、regexp或RE),计算机科学的一个概念。正则表达式通常被用来检索、替换那些符合某个模式(规则)的文本。什么时候会用到正则表达式?数据验证字符串查找字符串替换正则表达式基础元字符描述

    2022年6月1日
    64
  • DELL服务器数据恢复成功案例

    DELL服务器数据恢复成功案例DELLEqualLogicPS6100采用虚拟ISCSISAN阵列,为远程或分支办公室、部门和中小企业存储部署带来企业级功能、智能化、自动化和可靠性。以简化的管理、快速的部署及合理的价格满足了分支办公室和中小企业的存储需求,同时提供全套企业级数据保护和管理功能、可靠的性能、可扩展性和容错功能,是中型企业级存储的起点产品,但某些物理故障或其他操作都可能会对卷或存储造成破坏,因此对系列存储的数…

    2022年6月30日
    24
  • SQL语句LIKE CONCAT模糊查询

    SQL语句LIKE CONCAT模糊查询Oracle拼接字符串concat需要注意的小事项在用ssm框架编写代码的时候,因为数据库换成了Oracle,在模糊查询数据的时候突然发现报错了select*fromSYS_MENUwhereurllikeconcat(‘%’,#{roleName},’%’)一直报错参数个数无效,在网上查找资料发现模糊查询的sql语句还是concat(‘%’,’s’,’%’)这样写的…

    2022年5月29日
    35
  • zabbix监控mysql的哪些参数_Zabbix监控Mysql数据库性能

    zabbix监控mysql的哪些参数_Zabbix监控Mysql数据库性能在之前的博文里面写过如何通过Zabbix监控mysql主从同步是否OK,mysql从库是否有延时(Seconds_Behind_Master)主库,当mysql主从有异常时通过Email或者SMS通知DBA和系统人员。除此之外,Zabbix还可以监控mysqlslowqueries,mysqlversion,uptime,alive等。下面通过ZabbixGraphs实时查看的SQL语句操…

    2022年4月28日
    29
  • 基于Python的频谱分析(一)

    基于Python的频谱分析(一)

    2021年11月21日
    43

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号