一、model_fn
函数有5个输入参数features, labels, mode, params, config,并输出一个EstimatorSpec实例;
features:input_fn的第一个输出。labels:input_fn的第二个输出。mode:操作类型(是训练、预测还是评估),对应tf.estimator.ModeKeys.EVAL/TRAIN/PREDICT。params:定义Estimator实例时传入的params属性。config:定义Estimator实例时传入的config属性。- 输出
EstimatorSpec实例介绍:- 训练时:需要指定
loss和train_op。 - 预测时:需要指定
predictions。 - 评估时:需要指定
loss和metrics
- 训练时:需要指定
二、实例化Estimator
config参数:- 可用于设置训练过程中相关操作,主要就是summary/save/logging操作。
- 可用于设置
tf.Session的配置,session_config其实就是tf.ConfigProto对象。 - 可用于设置多GPU训练,即
train_distribute变量。
params参数:- 主要作用就是可以传入
model_fn中,帮助实现各类功能。 - 模型参数:以Faster R-CNN为例,可以选择backbone参数,anchors参数,weight decay参数等。
- 训练参数:如优化器类型及参数、学习率参数。
- 性能指标:如选择那些性能指标进行计算等。
- 主要作用就是可以传入
model_fnmodel_dir:summary和save的路径config:tf.estimator.RunConfig实例params:输入参数,会传输到model_fn中。warm_start_from:热启动功能,暂时没碰到做啥用的
三、tf.estimator.Estimator训练、预测、评估
def train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None):
# input_fn 在`1. 数据集`中介绍 # predict_keys 字符串列表,当EstimatorSpec.predictions是字典时使用 # hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务 # checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件 def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
# input_fn 在`1. 数据集`中介绍 # steps 评估次数最大值 # hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务 # checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件 # name 名称,好像用于记录不同数据集上的结果,将评估结果保存到不同文件夹中 def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None):
四、tf.data
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10) dataset = dataset.map(lambda x: x + 10) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(2): value = sess.run(next_element) print(value)

参数说明:
batch:指更新梯度中使用的样本数;
repeat:将数据重复多次,主要用来处理epoch;
shuffle:打乱dataset中的元素;
map:Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的DataSet;
发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/218986.html原文链接:https://javaforall.net
