dataloader 源码_DataLoader

dataloader 源码_DataLoaderimportpaddle.fluidasfluidimportnumpyasnpBATCH_NUM=10BATCH_SIZE=16EPOCH_NUM=4CLASS_NUM=10ITERABLE=True#whetherthecreatedDataLoaderobjectisiterableUSE_GPU=False#whethertous…

大家好,又见面了,我是你们的朋友全栈君。

import paddle.fluid as fluid

import numpy as np

BATCH_NUM = 10

BATCH_SIZE = 16

EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable

USE_GPU = False # whether to use GPU

DATA_FORMAT = ‘batch_generator’ # data format of data source user provides

def simple_net(image, label):

fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)

cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)

loss = fluid.layers.reduce_mean(cross_entropy)

sgd = fluid.optimizer.SGD(learning_rate=1e-3)

sgd.minimize(loss)

return loss

def get_random_images_and_labels(image_shape, label_shape):

image = np.random.random(size=image_shape).astype(‘float32’)

label = np.random.random(size=label_shape).astype(‘int64’)

return image, label

# If the data generator yields one sample each time,

# use DataLoader.set_sample_generator to set the data source.

def sample_generator_creator():

def __reader__():

for _ in range(BATCH_NUM * BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

yield image, label

return __reader__

# If the data generator yield list of samples each time,

# use DataLoader.set_sample_list_generator to set the data source.

def sample_list_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

sample_list = []

for _ in range(BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

sample_list.append([image, label])

yield sample_list

return __reader__

# If the data generator yields a batch each time,

# use DataLoader.set_batch_generator to set the data source.

def batch_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])

yield batch_image, batch_label

return __reader__

# If DataLoader is iterable, use for loop to train the network

def train_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

for data in loader():

exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use start() and reset() method to control the process

def train_non_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

loader.start() # call DataLoader.start() before each epoch starts

try:

while True:

exe.run(prog, fetch_list=[loss])

except fluid.core.EOFException:

loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):

if DATA_FORMAT == ‘sample_generator’:

loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)

elif DATA_FORMAT == ‘sample_list_generator’:

loader.set_sample_list_generator(sample_list_generator_creator(), places=places)

elif DATA_FORMAT == ‘batch_generator’:

loader.set_batch_generator(batch_generator_creator(), places=places)

else:

raise ValueError(‘Unsupported data format’)

image = fluid.layers.data(name=’image’, shape=[784], dtype=’float32′)

label = fluid.layers.data(name=’label’, shape=[1], dtype=’int64′)

# Define DataLoader

loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network

loss = simple_net(image, label)

# Set data source of DataLoader

#

# If DataLoader is iterable, places must be given and the number of places must be the same with device number.

# – If you are using GPU, call `fluid.cuda_places()` to get all GPU places.

# – If you are using CPU, call `fluid.cpu_places()` to get all CPU places.

#

# If DataLoader is not iterable, places can be None.

places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

set_data_source(loader, places)

exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:

train_iterable(exe, prog, loss, loader)

else:

train_non_iterable(exe, prog, loss, loader)

”’

Users can use return_list = True in dygraph mode.

”’

with fluid.dygraph.guard(places[0]):

loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)

set_data_source(loader, places[0])

for image, label in loader():

relu = fluid.layers.relu(image)

assert image.shape == [BATCH_SIZE, 784]

assert label.shape == [BATCH_SIZE, 1]

assert relu.shape == [BATCH_SIZE, 784]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/132199.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • sql中declare的用法_sql局部变量

    sql中declare的用法_sql局部变量换工作了,以后主要和SqlServer打交道了,仿佛回到了大学,不知道学校的饭还是那么好吃又便宜吗?北京的饭好贵;不知道门口哪家板面的生意是不是还是那么红火,好想再去吃一碗。。。咳咳,不多说了,直接进入主题declare这个类型,其实可以理解为Java里面的public类型变量,全局有效,当然非要较真的话,我觉得归到protected类也可以(不理解的话不要看后半段,只是为了严谨)Java修饰符 public:对所有类可见。使用对象:类、接口、变量、方法 protect..

    2022年8月20日
    8
  • Q-学习:强化学习「建议收藏」

    Q-学习:强化学习「建议收藏」原文地址:http://mnemstudio.org/path-finding-q-learning-tutorial.htm这篇教程通过简单且易于理解的实例介绍了Q-学习的概念知识,例子描述了一个智能体通过非监督学习的方法对未知的环境进行学习。假设我们的楼层内共有5个房间,房间之间通过一道门相连,正如下图所示。我们将房间编号为房间0到房间4,楼层的外部可以被看作是一间大房间,编号为5。注

    2022年9月25日
    2
  • vue环境安装与配置(Linux安装常用开发工具)

    vue安装环境搭建提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录vue安装环境搭建前言一、node.js安装和配置1.下载安装node.js2.配置默认安装目录和缓存日志目录3.node.js环境配置4.配置淘宝镜像源二、使用步骤1.引入库2.读入数据总结前言vue前端框架的环境搭建一、node.js安装和配置1.下载安装node.js官网下载最新版本:https://nodejs.org/en/download/可以下载安装包(安装教程见:http

    2022年4月18日
    69
  • C语言贪吃蛇代码_c语言贪吃蛇游戏

    C语言贪吃蛇代码_c语言贪吃蛇游戏一、C语言贪吃蛇代码实现前言设计贪吃蛇游戏的主要目的是让大家夯实C语言基础,训练编程思维,培养解决问题的思路,领略多姿多彩的C语言。贪吃蛇是非常经典的一款游戏,本次我们模拟在控制台实现贪吃蛇游戏,也就是实现贪吃蛇的基本功能,比如在地图中,用“↑↓←→”控制移动蛇的方向,吃掉食物之后,蛇身体会变长等等。。。。首先我们得分析,游戏中我们会碰见的一些情况。①蛇的部分,蛇的身子是一节一节的,此时最容易联想到的数据结构就是顺序表,链表,如果把蛇比做顺序表或者链表,在之后吃到食物的时候,身子肯定会变长,

    2025年9月6日
    5
  • uniapp打包ios真机测试_苹果手机没有app可供测试

    uniapp打包ios真机测试_苹果手机没有app可供测试这里说的是数据线链接首先要给电脑上安装一个手机驱动我用的是360助手(链接如下:)http://sj.360.cn/index.html先测试手机和电脑是否能连接成功然后再HBuilderX中点击运行到手机模拟器就可以了,注意:连接手机的时候会提示手机也需要下载360手机助手(按照提示下载即可),手机上会自动安装一个HBuilder,点击进去就是自己的app手机连接的时候要选中usb调试模式…

    2022年9月6日
    7
  • 安捷伦示波器连接电脑没反应_安捷伦示波器升级

    安捷伦示波器连接电脑没反应_安捷伦示波器升级调试或者测试电路的过程中,示波器是比不可少的东西,且常需要保存示波器的波形,可以通过一条网线把示波器与电脑连接起来,这样就可以在电脑上方便的保存波形图,怎么实现示波器和电脑连接,我们以安捷伦示波器DSO7052B为例,其他示波器类似。1.首先通过一条网线连接示波器和电脑,在电脑上找到:网络和共享中心->以太网状态->以太网属性->TCP/IPv4协议,点击属性2…

    2022年10月12日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号