dataloader 源码_DataLoader

dataloader 源码_DataLoaderimportpaddle.fluidasfluidimportnumpyasnpBATCH_NUM=10BATCH_SIZE=16EPOCH_NUM=4CLASS_NUM=10ITERABLE=True#whetherthecreatedDataLoaderobjectisiterableUSE_GPU=False#whethertous…

大家好,又见面了,我是你们的朋友全栈君。

import paddle.fluid as fluid

import numpy as np

BATCH_NUM = 10

BATCH_SIZE = 16

EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable

USE_GPU = False # whether to use GPU

DATA_FORMAT = ‘batch_generator’ # data format of data source user provides

def simple_net(image, label):

fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)

cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)

loss = fluid.layers.reduce_mean(cross_entropy)

sgd = fluid.optimizer.SGD(learning_rate=1e-3)

sgd.minimize(loss)

return loss

def get_random_images_and_labels(image_shape, label_shape):

image = np.random.random(size=image_shape).astype(‘float32’)

label = np.random.random(size=label_shape).astype(‘int64’)

return image, label

# If the data generator yields one sample each time,

# use DataLoader.set_sample_generator to set the data source.

def sample_generator_creator():

def __reader__():

for _ in range(BATCH_NUM * BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

yield image, label

return __reader__

# If the data generator yield list of samples each time,

# use DataLoader.set_sample_list_generator to set the data source.

def sample_list_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

sample_list = []

for _ in range(BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

sample_list.append([image, label])

yield sample_list

return __reader__

# If the data generator yields a batch each time,

# use DataLoader.set_batch_generator to set the data source.

def batch_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])

yield batch_image, batch_label

return __reader__

# If DataLoader is iterable, use for loop to train the network

def train_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

for data in loader():

exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use start() and reset() method to control the process

def train_non_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

loader.start() # call DataLoader.start() before each epoch starts

try:

while True:

exe.run(prog, fetch_list=[loss])

except fluid.core.EOFException:

loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):

if DATA_FORMAT == ‘sample_generator’:

loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)

elif DATA_FORMAT == ‘sample_list_generator’:

loader.set_sample_list_generator(sample_list_generator_creator(), places=places)

elif DATA_FORMAT == ‘batch_generator’:

loader.set_batch_generator(batch_generator_creator(), places=places)

else:

raise ValueError(‘Unsupported data format’)

image = fluid.layers.data(name=’image’, shape=[784], dtype=’float32′)

label = fluid.layers.data(name=’label’, shape=[1], dtype=’int64′)

# Define DataLoader

loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network

loss = simple_net(image, label)

# Set data source of DataLoader

#

# If DataLoader is iterable, places must be given and the number of places must be the same with device number.

# – If you are using GPU, call `fluid.cuda_places()` to get all GPU places.

# – If you are using CPU, call `fluid.cpu_places()` to get all CPU places.

#

# If DataLoader is not iterable, places can be None.

places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

set_data_source(loader, places)

exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:

train_iterable(exe, prog, loss, loader)

else:

train_non_iterable(exe, prog, loss, loader)

”’

Users can use return_list = True in dygraph mode.

”’

with fluid.dygraph.guard(places[0]):

loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)

set_data_source(loader, places[0])

for image, label in loader():

relu = fluid.layers.relu(image)

assert image.shape == [BATCH_SIZE, 784]

assert label.shape == [BATCH_SIZE, 1]

assert relu.shape == [BATCH_SIZE, 784]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/132199.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • C语言 neutralize函数,三种常用分子模拟软件绍.doc

    C语言 neutralize函数,三种常用分子模拟软件绍.doc三种常用分子模拟软件绍三种常用分子模拟软件介绍一、NAMD  NAMD(NAnoscaleMolecularDynamics)是用于在大规模并行计算机上快速模拟大分子体系的并行分子动力学代码。NAMD用经验力场,如Amber,CHARMM和Dreiding,通过数值求解运动方程计算原子轨迹。  1.软件所能模拟的体系的尺度,如微观,介观或跨尺度等  微观。  是众多md软件中并行处理…

    2022年5月25日
    34
  • springboot上传文件到文件夹

    springboot上传文件到文件夹springboot上传文件至项目当前路径下的文件夹关键代码,之后会分享完整代码到gitee默认上传文件到文件夹/***默认上传文件到文件夹**@paramfolder默认文件夹*@paramfile上传的文件*@return*/privateStringmyfileUp(Stri…

    2022年5月30日
    27
  • pandas—dropna[通俗易懂]

    pandas—dropna[通俗易懂]文章目录1.pd.Series.dropna官方案例2.pd.DataFrame.dropna官方案例1.pd.Series.dropnaSeries.dropna(axis=0,inplace=False,how=None)描述返回删除了缺失值的新Series参数axis:{0or‘index’},default0只有一个轴可以从中删除值inplace:bool,defaultFalse如果为True,则就地修改返回None如果为False,则

    2025年6月3日
    2
  • oracle与mysql的存储区别_存储过程和触发器的区别和联系

    oracle与mysql的存储区别_存储过程和触发器的区别和联系1.创建存储过程语句不同oraclecreateorreplaceprocedureP_ADD_FAC(id_fac_cdINES_FAC_UNIT.FAC_CD%TYPE)asmysqlDROPPROCEDUREIFEXISTS`SD_USER_P_ADD_USR`;createprocedureP_ADD_FAC(id_fac_…

    2025年11月13日
    3
  • DHCP 协议详解

    DHCP 协议详解1DHCP协议1.1DHCP协议理解定义:DHCP:DynamicHostConfigurationProtocol,动态主机配置协议,是一个用于局域网的网络协议,位于OSI模型的应用层,使用UDP协议工作,主要有两个用途:用于内部网或网络服务供应商自动分配IP地址给用户 用于内部网管理员对所有电脑作中央管理作用:动态分配IP地址,过程自动化,终端无需一一…

    2022年5月24日
    34
  • css动画和js动画的优缺点_彼得兔第三季动画片

    css动画和js动画的优缺点_彼得兔第三季动画片大家好,我是小丞同学,一名准大二的前端爱好者这篇文章将欢快的带你了解一下CSS和JS动画的差别愿你忠于自己,热爱生活引言讲到动画,当然是非常有意思的啦,你可以往上滑一下,看看上面的封面图,是不是相当的炫酷,以为我是代码写出来的吗?那当然不可能啊,我这么摸鱼,怎么会为了个封面图上号呢废话不多说,其实上面的动图用代码实现也不会很困难,这个图是用canva做出来的。本文主要讲以下这些内容浏览器渲染流程回流和重绘CSS动画JS动画两者对比

    2022年10月15日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号