estimator使用

estimator使用一 model fn 函数有 5 个输入参数 features labels mode params config 并输出一个 EstimatorSpe 实例 features input fn 的第一个输出 labels input fn 的第二个输出 mode 操作类型 是训练 预测还是评估 对应 tf estimator ModeKeys EVAL TRAIN PREDICT params 定义 Estimator 实例时传入的 params 属性 config 定义 Estimator 实例时

一、model_fn

函数有5个输入参数features, labels, mode, params, config,并输出一个EstimatorSpec实例;

  • featuresinput_fn的第一个输出。
  • labelsinput_fn的第二个输出。
  • mode:操作类型(是训练、预测还是评估),对应tf.estimator.ModeKeys.EVAL/TRAIN/PREDICT
  • params:定义Estimator实例时传入的params属性。
  • config:定义Estimator实例时传入的config属性。
  • 输出EstimatorSpec实例介绍:
    • 训练时:需要指定losstrain_op
    • 预测时:需要指定predictions
    • 评估时:需要指定lossmetrics

二、实例化Estimator

  • config参数:
    • 可用于设置训练过程中相关操作,主要就是summary/save/logging操作。
    • 可用于设置 tf.Session 的配置,session_config 其实就是 tf.ConfigProto 对象。
    • 可用于设置多GPU训练,即train_distribute变量。
  • params参数:
    • 主要作用就是可以传入model_fn中,帮助实现各类功能。
    • 模型参数:以Faster R-CNN为例,可以选择backbone参数,anchors参数,weight decay参数等。
    • 训练参数:如优化器类型及参数、学习率参数。
    • 性能指标:如选择那些性能指标进行计算等。
  • model_fn
  • model_dir:summary和save的路径
  • configtf.estimator.RunConfig实例
  • params:输入参数,会传输到 model_fn 中。
  • warm_start_from:热启动功能,暂时没碰到做啥用的

三、tf.estimator.Estimator训练、预测、评估

def train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None):
# input_fn 在`1. 数据集`中介绍 # predict_keys 字符串列表,当EstimatorSpec.predictions是字典时使用 # hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务 # checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件 def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
# input_fn 在`1. 数据集`中介绍 # steps 评估次数最大值 # hooks 一组`tf.train.SessionRunHook`实例,用于完成各种任务 # checkpoint_path ckpt文件的路径(包括ckpt),默认使用`modol_dir`中最新的ckpt文件 # name 名称,好像用于记录不同数据集上的结果,将评估结果保存到不同文件夹中 def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None):

四、tf.data

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10) dataset = dataset.map(lambda x: x + 10) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess:     for i in range(2):         value = sess.run(next_element)         print(value)

estimator使用

参数说明:

batch:指更新梯度中使用的样本数;

repeat:将数据重复多次,主要用来处理epoch;

shuffle:打乱dataset中的元素;

map:Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的DataSet;

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/218986.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 下午11:03
下一篇 2026年3月17日 下午11:03


相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号