tf.estimator.Estimator explained

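tf.estimator.Estimator

Overview

tf.estimator.Estimator is a class, so it has to be instantiated before use. It is used to train and evaluate TensorFlow models. An Estimator object wraps a model that is specified by a function named model_fn. Given inputs and a number of other parameters, model_fn returns the ops needed to perform training, evaluation, or prediction. All outputs (checkpoints, event files, etc.) are written to model_dir or a subdirectory of it.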

Initialization

    __init__(
        model_fn,
        model_dir=None,
        config=None,
        params=None,
        warm_start_from=None
    )
    '''
    Args:
      model_fn: Model function. Follows the signature:
        Args:
          features: A dict of tensors or a single tensor returned by input_fn.
            In essence this is the model's input (where we used to feed inputs
            through tf.placeholder, here they are returned by the input_fn
            function). This is the first item returned from the input_fn.
          labels: A dict of tensors or a single tensor returned by input_fn.
            Note: if mode == tf.estimator.ModeKeys.PREDICT (i.e. at prediction
            time), labels will be set to None. This is the second item returned
            from the input_fn.
          mode: Optional. Specifies if this is training, evaluation or
            prediction. See tf.estimator.ModeKeys.
          params: Optional dict of hyperparameters. Receives the params argument
            given when the Estimator instance was created.
          config: Optional estimator.RunConfig object. Receives the config
            argument given when the Estimator instance was created, or a default
            value. Allows setting up things in your model_fn based on
            configuration such as num_ps_replicas, or model_dir.
        Returns:
          tf.estimator.EstimatorSpec. Pay close attention here: model_fn must
          return an EstimatorSpec instance.
      model_dir: Output directory. Everything the model writes out goes here.
      config: A class holding the fixed, officially defined configuration
        settings. If it does not cover your needs and you want to add your own
        parameters, use the params argument below.
      params: dict of hyper parameters that will be passed into model_fn. Keys
        are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
        warm-start from, or a tf.estimator.WarmStartSettings object to fully
        configure warm-starting. If the string filepath is provided instead of a
        tf.estimator.WarmStartSettings, then all variables are warm-started, and
        it is assumed that vocabularies and tf.Tensor names are unchanged.
    '''
Key points

The config argument can be passed a tf.estimator.RunConfig object containing information about the execution environment. It is passed on to the model_fn, if the model_fn has a parameter named “config” (and to the input functions in the same manner). If the config parameter is not passed, it is instantiated by the Estimator. Not passing config means that defaults useful for local execution are used. Estimator makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing.

The params argument contains hyperparameters. It is passed to the model_fn, if the model_fn has a parameter named “params”, and to the input functions in the same manner. Estimator only passes params along; it does not inspect it. The structure of params is therefore entirely up to the developer.
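The rule that params and config reach model_fn only when model_fn declares a parameter with that name can be illustrated in plain Python. This is a simplified, hypothetical model of the dispatch, not TensorFlow's internal code; call_model_fn, my_model_fn and minimal_model_fn are illustrative names.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params, config):
    """Hypothetical helper: forward optional args only if model_fn declares them."""
    declared = inspect.signature(model_fn).parameters
    kwargs = {}
    # Each optional argument is passed only when model_fn has a parameter of that name.
    for name, value in (("labels", labels), ("mode", mode),
                        ("params", params), ("config", config)):
        if name in declared:
            kwargs[name] = value
    # features is always passed; params is forwarded untouched, never inspected.
    return model_fn(features=features, **kwargs)


def my_model_fn(features, labels, mode, params):
    # params is whatever dict the user gave the Estimator; its layout is up to you.
    return {"learning_rate": params["learning_rate"], "mode": mode}


def minimal_model_fn(features):
    # No params/config parameters, so they are simply not passed.
    return features


print(call_model_fn(my_model_fn, {"x": [1.0]}, [0], "train",
                    {"learning_rate": 0.1}, config=None))
# {'learning_rate': 0.1, 'mode': 'train'}
print(call_model_fn(minimal_model_fn, {"x": [1.0]}, None, "infer", {}, None))
# {'x': [1.0]}
```

Dispatching on parameter names is what lets a model_fn opt in to params or config without the Estimator requiring a fixed signature.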

Methods

train method

Trains the model on data fetched from input_fn.

    train(
        input_fn,
        hooks=None,
        steps=None,
        max_steps=None,
        saving_listeners=None
    )
    '''
    Args:
      input_fn: A function that provides input data for training as minibatches.
        See Premade Estimators for more information. The function should
        construct and return one of the following:
          * A tf.data.Dataset object: Outputs of Dataset object must be a tuple
            (features, labels) with same constraints as below.
          * A tuple (features, labels): Where features is a tf.Tensor or a
            dictionary of string feature name to Tensor and labels is a Tensor
            or a dictionary of string label name to Tensor. Both features and
            labels are consumed by model_fn. They should satisfy the expectation
            of model_fn from inputs.
      hooks: List of tf.train.SessionRunHook subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train the model. If None, train
        forever or train until input_fn generates the tf.errors.OutOfRange error
        or StopIteration exception. steps works incrementally. If you call two
        times train(steps=10) then training occurs in total 20 steps. If
        OutOfRange or StopIteration occurs in the middle, training stops before
        20 steps. If you don't want to have incremental behavior please set
        max_steps instead. If set, max_steps must be None.
      max_steps: Number of total steps for which to train model. If None, train
        forever or train until input_fn generates the tf.errors.OutOfRange error
        or StopIteration exception. If set, steps must be None. If OutOfRange or
        StopIteration occurs in the middle, training stops before max_steps
        steps. Two calls to train(steps=100) means 200 training iterations. On
        the other hand, two calls to train(max_steps=100) means that the second
        call will not do any iteration since first call did all 100 steps.
      saving_listeners: list of CheckpointSaverListener objects. Used for
        callbacks that run immediately before or after checkpoint savings.
    Returns:
      self, for chaining.
    '''
Main parameters

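input_fn: a function that supplies the training input data, one batch (batch_size examples) at a time. It returns (features, labels), which is exactly what model_fn takes as input. Its return value should be one of the following: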

  1. A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels)
  2. A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor
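To make the (features, labels) contract concrete without requiring TensorFlow, here is a plain-Python stand-in for an input_fn that yields minibatches. It is only an illustration: make_input_fn and the column data are hypothetical. In real TensorFlow code the same thing would usually be a tf.data.Dataset built with tf.data.Dataset.from_tensor_slices((features, labels)).batch(batch_size).

```python
def make_input_fn(columns, labels, batch_size):
    """Hypothetical plain-Python stand-in for a tf.data-based input_fn."""

    def input_fn():
        for start in range(0, len(labels), batch_size):
            end = start + batch_size
            # features: dict of feature name -> the values for this batch
            features = {name: values[start:end] for name, values in columns.items()}
            # labels: the label values for the same examples
            yield features, labels[start:end]
        # When the generator is exhausted, the caller sees StopIteration,
        # the plain-Python analogue of "end of input".

    return input_fn


input_fn = make_input_fn(
    columns={"age": [23, 35, 41, 52, 60], "income": [1.0, 2.5, 3.0, 4.2, 5.1]},
    labels=[0, 1, 1, 0, 1],
    batch_size=2,
)

for features, batch_labels in input_fn():
    print(features, batch_labels)
# {'age': [23, 35], 'income': [1.0, 2.5]} [0, 1]
# {'age': [41, 52], 'income': [3.0, 4.2]} [1, 0]
# {'age': [60], 'income': [5.1]} [1]
```

The last batch is smaller than batch_size, which is also what an unpadded tf.data batch does at the end of the data.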

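max_steps: the maximum number of steps to train (i.e. the total number of batches). If training is stopped and later resumed, the program checks whether the number of steps already trained is greater than or equal to max_steps; if so, no further training is done. (On the other hand, two calls to train(max_steps=100) means that the second call will not do any iteration since first call did all 100 steps.)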

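steps: continues training "incrementally" on top of the steps already done. For example, if you call train(input_fn, steps=10) twice, the model has been trained for 20 iterations in total.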
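The difference between the incremental steps and the absolute max_steps can be modelled with a small counter. This is a hypothetical plain-Python sketch of the rules stated in the docstring (StepCounter and available_batches are illustrative names, not TensorFlow API); global_step stands for the value a real Estimator would restore from the latest checkpoint in model_dir.

```python
import math


class StepCounter:
    """Hypothetical plain-Python model of train()'s steps/max_steps rules."""

    def __init__(self):
        # In a real Estimator this value is restored from the latest checkpoint.
        self.global_step = 0

    def train(self, steps=None, max_steps=None, available_batches=math.inf):
        if steps is not None and max_steps is not None:
            raise ValueError("steps and max_steps cannot both be set")
        if steps is not None:
            target = self.global_step + steps   # incremental: relative to now
        elif max_steps is not None:
            target = max_steps                  # absolute: total step count
        else:
            target = math.inf                   # run until the input is exhausted
        # Each iteration consumes one batch; stop early if the input runs out
        # (the OutOfRange / StopIteration case). If both target and
        # available_batches are infinite, this loops forever, like real training.
        while self.global_step < target and available_batches > 0:
            self.global_step += 1
            available_batches -= 1
        return self  # train() returns self, allowing chaining


a = StepCounter()
a.train(steps=10).train(steps=10)
print(a.global_step)  # 20

b = StepCounter()
b.train(max_steps=100).train(max_steps=100)
print(b.global_step)  # 100: the second call does nothing

c = StepCounter()
c.train(steps=20, available_batches=15)
print(c.global_step)  # 15: the input ran out first
```

Because max_steps is compared against the restored global_step, it is the safer choice when a script may be rerun after an interruption.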

evaluate method

    evaluate(
        input_fn,
        steps=None,
        hooks=None,
        checkpoint_path=None,
        name=None
    )
    '''
    Args:
      input_fn: A function that constructs the input data for evaluation. See
        Premade Estimators for more information. The function should construct
        and return one of the following:
          * A tf.data.Dataset object: Outputs of Dataset object must be a tuple
            (features, labels) with same constraints as below.
          * A tuple (features, labels): Where features is a tf.Tensor or a
            dictionary of string feature name to Tensor and labels is a Tensor
            or a dictionary of string label name to Tensor. Both features and
            labels are consumed by model_fn. They should satisfy the expectation
            of model_fn from inputs.
      steps: Number of steps for which to evaluate model. If None, evaluates
        until input_fn raises an end-of-input exception.
      hooks: List of tf.train.SessionRunHook subclass instances. Used for
        callbacks inside the evaluation call.
      checkpoint_path: Path of a specific checkpoint to evaluate. If None, the
        latest checkpoint in model_dir is used. If there are no checkpoints in
        model_dir, evaluation is run with newly initialized Variables instead of
        ones restored from checkpoint.
      name: Name of the evaluation if user needs to run multiple evaluations on
        different data sets, such as on training data vs test data. Metrics for
        different evaluations are saved in separate folders, and appear
        separately in tensorboard.
    Returns:
      A dict containing the evaluation metrics specified in model_fn keyed by
      name, as well as an entry global_step which contains the value of the
      global step for which this evaluation was performed. For canned
      estimators, the dict contains the loss (mean loss per mini-batch) and the
      average_loss (mean loss per sample). Canned classifiers also return the
      accuracy. Canned regressors also return the label/mean and the
      prediction/mean.
    '''

Parameter notes

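    # acc_op is assumed to come from tf.metrics.accuracy(labels=labels, predictions=pred_classes),
    # i.e. a (metric_tensor, update_op) pair, which is the form eval_metric_ops expects.
    estim_specs = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_classes,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops={"accuracy": acc_op})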

With eval_metric_ops={"accuracy": acc_op} in the EstimatorSpec above, evaluate finally returns a dict like this:

{'accuracy': 0.9192, 'loss': 0., 'global_step': 1000} 
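A rough plain-Python model of how such a dict is assembled: accuracy is accumulated over every example seen (a streaming metric), while loss is averaged over batches. The per-batch averaging follows the docstring's "mean loss per mini-batch" description for canned estimators; applying it to a custom model_fn is an assumption of this sketch, and summarize_evaluation is a hypothetical name.

```python
def summarize_evaluation(batches, global_step):
    """Hypothetical sketch of how evaluate() aggregates metrics over batches.

    Each batch is (predicted_classes, true_labels, batch_loss).
    """
    correct = total = 0
    batch_losses = []
    for predicted, labels, batch_loss in batches:
        # Streaming accuracy: count correct examples across all batches.
        correct += sum(int(p == y) for p, y in zip(predicted, labels))
        total += len(labels)
        # 'loss' is the mean of the per-batch loss values.
        batch_losses.append(batch_loss)
    return {
        "accuracy": correct / total,
        "loss": sum(batch_losses) / len(batch_losses),
        "global_step": global_step,
    }


batches = [
    ([1, 0, 2], [1, 1, 2], 0.5),   # 2 of 3 correct
    ([0], [0], 0.25),              # 1 of 1 correct
]
print(summarize_evaluation(batches, global_step=1000))
# {'accuracy': 0.75, 'loss': 0.375, 'global_step': 1000}
```

Note that accuracy weights every example equally (3/4), whereas loss weights every batch equally, so a small final batch counts as much as a full one.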

predict method

    predict(
        input_fn,
        predict_keys=None,
        hooks=None,
        checkpoint_path=None,
        yield_single_examples=True
    )
    '''
    Args:
      input_fn: A function that constructs the features. Prediction continues
        until input_fn raises an end-of-input exception
        (tf.errors.OutOfRangeError or StopIteration). See Premade Estimators for
        more information. The function should construct and return one of the
        following:
          * A tf.data.Dataset object: Outputs of Dataset object must have same
            constraints as below.
          * features: A tf.Tensor or a dictionary of string feature name to
            Tensor. features are consumed by model_fn. They should satisfy the
            expectation of model_fn from inputs.
          * A tuple, in which case the first item is extracted as features.
      predict_keys: list of str, name of the keys to predict. It is used if the
        tf.estimator.EstimatorSpec.predictions is a dict. If predict_keys is used
        then rest of the predictions will be filtered from the dictionary. If
        None, returns all.
      hooks: List of tf.train.SessionRunHook subclass instances. Used for
        callbacks inside the prediction call.
      checkpoint_path: Path of a specific checkpoint to predict. If None, the
        latest checkpoint in model_dir is used. If there are no checkpoints in
        model_dir, prediction is run with newly initialized Variables instead of
        ones restored from checkpoint.
      yield_single_examples: If False, yields the whole batch as returned by the
        model_fn instead of decomposing the batch into individual elements. This
        is useful if model_fn returns some tensors whose first dimension is not
        equal to the batch size.
    '''
Notes

Given the input, predict yields whatever model_fn specifies as its output through tf.estimator.EstimatorSpec(mode, predictions=pred_classes):

    ....
    ....
    pred_classes = tf.argmax(logits, axis=1)
    pred_probas = tf.nn.softmax(logits)

    # PREDICTS
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)
    .....
    ......
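predict is a generator. The sketch below models, in plain Python, two of the arguments documented above: predict_keys filtering a dict of predictions, and yield_single_examples splitting each batch into individual examples. predict_sketch and the batch data are hypothetical; the real method runs model_fn in PREDICT mode and pulls batches from TensorFlow, which this sketch leaves out. A plain list per batch corresponds to the single-tensor case predictions=pred_classes shown above.

```python
def predict_sketch(batch_outputs, predict_keys=None, yield_single_examples=True):
    """Hypothetical sketch of predict()'s output handling (no TensorFlow).

    batch_outputs: iterable of per-batch predictions, each either a dict
    of name -> list (one entry per example) or a plain list.
    """
    for batch in batch_outputs:
        if isinstance(batch, dict) and predict_keys is not None:
            unknown = set(predict_keys) - set(batch)
            if unknown:
                raise ValueError(f"unknown predict_keys: {sorted(unknown)}")
            # Keep only the requested keys; the rest are filtered out.
            batch = {k: batch[k] for k in predict_keys}
        if not yield_single_examples:
            yield batch  # the whole batch, as returned by model_fn
        elif isinstance(batch, dict):
            # Split the batch into one dict per example.
            batch_size = len(next(iter(batch.values())))
            for i in range(batch_size):
                yield {k: v[i] for k, v in batch.items()}
        else:
            # A single prediction tensor: yield one element at a time.
            yield from batch


batches = [{"class": [2, 0], "prob": [0.9, 0.8]}, {"class": [1], "prob": [0.7]}]
print(list(predict_sketch(batches, predict_keys=["class"])))
# [{'class': 2}, {'class': 0}, {'class': 1}]
print(list(predict_sketch([[2, 0], [1]])))
# [2, 0, 1]
```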

Publisher: 全栈程序员-站长 (Full-Stack Programmer, site admin). Please credit the source when reposting: https://javaforall.net/206614.html Original link: https://javaforall.net
