dataloader 源码_DataLoader

dataloader 源码_DataLoaderimportpaddle.fluidasfluidimportnumpyasnpBATCH_NUM=10BATCH_SIZE=16EPOCH_NUM=4CLASS_NUM=10ITERABLE=True#whetherthecreatedDataLoaderobjectisiterableUSE_GPU=False#whethertous…

大家好,又见面了,我是你们的朋友全栈君。

import paddle.fluid as fluid

import numpy as np

BATCH_NUM = 10

BATCH_SIZE = 16

EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable

USE_GPU = False # whether to use GPU

DATA_FORMAT = ‘batch_generator’ # data format of data source user provides

def simple_net(image, label):

fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)

cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)

loss = fluid.layers.reduce_mean(cross_entropy)

sgd = fluid.optimizer.SGD(learning_rate=1e-3)

sgd.minimize(loss)

return loss

def get_random_images_and_labels(image_shape, label_shape):

image = np.random.random(size=image_shape).astype(‘float32’)

label = np.random.random(size=label_shape).astype(‘int64’)

return image, label

# If the data generator yields one sample each time,

# use DataLoader.set_sample_generator to set the data source.

def sample_generator_creator():

def __reader__():

for _ in range(BATCH_NUM * BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

yield image, label

return __reader__

# If the data generator yield list of samples each time,

# use DataLoader.set_sample_list_generator to set the data source.

def sample_list_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

sample_list = []

for _ in range(BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

sample_list.append([image, label])

yield sample_list

return __reader__

# If the data generator yields a batch each time,

# use DataLoader.set_batch_generator to set the data source.

def batch_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])

yield batch_image, batch_label

return __reader__

# If DataLoader is iterable, use for loop to train the network

def train_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

for data in loader():

exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use start() and reset() method to control the process

def train_non_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

loader.start() # call DataLoader.start() before each epoch starts

try:

while True:

exe.run(prog, fetch_list=[loss])

except fluid.core.EOFException:

loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):

if DATA_FORMAT == ‘sample_generator’:

loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)

elif DATA_FORMAT == ‘sample_list_generator’:

loader.set_sample_list_generator(sample_list_generator_creator(), places=places)

elif DATA_FORMAT == ‘batch_generator’:

loader.set_batch_generator(batch_generator_creator(), places=places)

else:

raise ValueError(‘Unsupported data format’)

image = fluid.layers.data(name=’image’, shape=[784], dtype=’float32′)

label = fluid.layers.data(name=’label’, shape=[1], dtype=’int64′)

# Define DataLoader

loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network

loss = simple_net(image, label)

# Set data source of DataLoader

#

# If DataLoader is iterable, places must be given and the number of places must be the same with device number.

# – If you are using GPU, call `fluid.cuda_places()` to get all GPU places.

# – If you are using CPU, call `fluid.cpu_places()` to get all CPU places.

#

# If DataLoader is not iterable, places can be None.

places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

set_data_source(loader, places)

exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:

train_iterable(exe, prog, loss, loader)

else:

train_non_iterable(exe, prog, loss, loader)

”’

Users can use return_list = True in dygraph mode.

”’

with fluid.dygraph.guard(places[0]):

loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)

set_data_source(loader, places[0])

for image, label in loader():

relu = fluid.layers.relu(image)

assert image.shape == [BATCH_SIZE, 784]

assert label.shape == [BATCH_SIZE, 1]

assert relu.shape == [BATCH_SIZE, 784]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/132199.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 黑客入门视频教程(共57个)全实战过程

    黑客入门视频教程(共57个)全实战过程黑客入门视频教程(共57个)全实战过程 01ping命令的使用http://images.enet.com.cn/eschool/wmv/ping.wmv02netstat命令的使用http://images.enet.com.cn/eschool/wmv/netstat.wmv03tasklist和taskkill的使用h…

    2022年5月29日
    22
  • Electron那些事10:本地数据库sqlite

    Electron那些事10:本地数据库sqlite【前言】上一节讲了本地日志,本地数据(文件)的部分,详见:Electron那些事09:本地数据_uikoo9的博客-CSDN博客虽然本地日志可以记录日志信息,本地数据可以记录简单的配置文件,但是像一些复杂的业务,需要维护一个本地数据库进行查询,本节讲一下本地数据库sqlite【sqlite】sqlite是有名的本地数据库,在很多系统中都有应用,SQLiteHomePage当然也有nodejs的版本,一般配套和electron使用,sqlite3-np…

    2022年5月11日
    59
  • 人工势场法(APF) —— Path Planning「建议收藏」

    人工势场法(APF) —— Path Planning「建议收藏」版权声明:本文为博主原创博文,未经允许不得转载,若要转载,请说明出处并给出博文链接人工势场法(ArtificialPotentialField,APF)是一种将机器人的外形视为势场中的一个点,这个势场结合了对目标的吸引力和对障碍物的排斥力。得到的轨迹作为路径输出。该方法具有计算量小、容易理解等优点。然而,它们可能陷入势场的局部极小值而无法找到路径,或者无法找到最优路径。人工势场可以被视为与静电势场类似的连续方程(将机器人视为点电荷),或者通过场的运动可以使用一组语言规则进行离散…

    2022年6月17日
    42
  • oracle11g详细安装教程_oracle11g32位安装

    oracle11g详细安装教程_oracle11g32位安装1、首先从http://www.oracle.com/technetwork/database/enterprise-edition/downloads/index.html下载合适的oracle数据库版本。2、解压压缩包,点击setup.exe,开始安装,一下为安装步骤的截图:口令:oracle11g第四步如果不是集群服务器要选择单实例数据库安装。

    2022年9月21日
    1
  • 青龙面板从零搭建教程(一)

    青龙面板从零搭建教程(一)大家好,QX系列教程教会了大家js脚本挂机的基础玩法,Boxjs为这个玩法提升了不少可玩性,但是IOS系统下最多支持2个账号,许多助力需求无法满足,应群友要求出一个青龙从零开始搭建教程,欢迎大家入群交流:106511927注意教程看不懂的话可以进群找群主帮你代挂!如果本教程看不懂或者操作出现问题,证明您的计算机专业知识并不支持本文章的搭建操作。第一步购买云服务器个人推荐阿里云服务器1核2G即可搞活动一年一百来块钱系统选择CentOs7等待配置完成。百度搜索Finalshell下载安装

    2022年6月13日
    81
  • backbone中文_backbone公司

    backbone中文_backbone公司代码下载地址:下载地址支持的backbone为Ghostnet、Shufflenetv2、Mobilenetv3Small、EagleEye、EfficientNetLite-0、PP-LCNet-1x、SwinTrans-YOLOv5Requirementspipinstall-rrequirements.txtMulti-BackboneSubstitutionforYOLOs1、BaseModelTrainonVisdroneDataSet(Inp

    2022年8月16日
    7

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号