tf.estimator.train_and_evaluate 详解

tf.estimator.train_and_evaluate 详解tf estimator train and evaluate 是 TensorFlow1 4 0 版中引入的 API 根据官方文档的内容 其应该是用来替代 tf contrib learn Experiment 的 1 tf estimator train and evaluate 简介字面理解这个 API 就是用来 train 然后 evaluate 一个 Estimator 的 函数

TensorFlow 版本:1.11.0

在 TensorFlow 1.4 版本中,Google 新引入了一个新 API:tf.estimator.train_and_evaluate。提出这个 API 的目的是:代替 tf.contrib.learn.Experiment

1. tf.estimator.train_and_evaluate 简介

train_and_evaluate API 用来 train 然后 evaluate 一个 Estimator。调用方式如下:

tf.estimator.train_and_evaluate( estimator, train_spec, eval_spec ) 

这个函数除了 train 和 evaluate 之外,还可选的提供了模型的导出功能,这样就可以把一个训练好的模型直接转交给业务部门来使用了,可以算是“产学研”一条龙服务了。

该函数的参数有三个:

  • estimator:一个 Estimator 实例。
  • train_spec:一个 TrainSpec 实例。用来配置训练过程。
  • eval_spec:一个 EvalSpec 实例。用来配置评估过程、(可选)模型的导出。

该函数的返回值有一个:

  • Estimator.evaluate 的结果 及 前面指定的 ExportStrategy 的输出结果。当前,尚未定义分布式训练模式的返回值。

实际上,如果直接使用 Estimator API,完成 train 和 evaluate 已经是很简单的任务了,为什么我们还要使用 train_and_evaluate 这个函数呢?按官方文档的说法:这个函数可以保证 本地 和 分布式 环境下行为的一致性。也就是说,使用 Estimatortrain_and_evaluate 编写的程序同时支持本地、集群上的训练,而不需要修改任何代码。可以想像一下,在完成了本地 CPU 训练的测试之后,直接 push 到 Cloud ML Engine 上,分分钟完成一个模型的训练,甚至还可以直接使用 TPU 集群(只要你保证模型里的 op 都是对 TPU 兼容的),多么方便的一个工具啊!

这个函数默认的分布式策略是:parameter server-based between-graph replication。对于其它的分布式策略的使用,可以参照 DistributionStrategies 。TensorFlow 关于分布式的官方文档见 Distributed TensorFlow

当然,方便的背后一般都有代价。为了保证代码在本地和集群上都可以正常终止,所以只能使用 Estimator 的 max_steps 参数设定终止条件。所以,如果想使用别的方式终止训练,可能就需要一些“技巧”了。

2. 参数说明

上面我们已经知道 train_and_evaluate 有三个参数,第一个先放在一边,因为这个参数就是一个 Estimator 的实例。我们先来看一下另外两个参数:

2.1 train_spec 参数

train_spec 参数接收一个 tf.estimator.TrainSpec 实例。

# TrainSpec的参数 __new__( cls, # 这个参数不用指定,忽略即可。 input_fn, max_steps=None, hooks=None ) 

其中:

  • input_fn: 参数用来指定数据输入。
  • max_steps: 参数用来指定训练的最大步数,这是训练的唯一终止条件。
  • hooks: 参数用来挂一些 tf.train.SessionRunHook,用来在 session 运行的时候做一些额外的操作,比如记录一些 TensorBoard 日志什么的。

2.2 eval_spec 参数

eval_spec 参数接收一个 tf.estimator.EvalSpec 实例。相比 TrainSpecEvalSpec 的参数多很多。因为 EvalSpec 不仅可以指定评估过程,还可以指定导出模型的功能(可选)。

__new__( cls, # 这个参数不用指定,忽略即可。 input_fn, steps=100, # 评估的迭代步数,如果为None,则在整个数据集上评估。 name=None, hooks=None, exporters=None, start_delay_secs=120, throttle_secs=600 ) 

其中:

  • input_fn: 含义同2.1。
  • steps: 用来指定评估的迭代步数,如果为None,则在整个数据集上评估。
  • name:如果要在多个数据集上进行评估,通过 name 参数可以保证不同数据集上的评估日志保存在不同的文件夹中,从而区分不同数据集上的评估日志。
    不同的评估日志保存在独立的文件夹中,在 TensorBoard 中从而独立的展现。

  • hooks:含义同2.1
  • exporters:一个 tf.estimator.export 模块中的类的实例。
  • start_delay_secs:调用 train_and_evaluate 函数后,多少秒之后开始评估。第一次评估发生在 start_delay_secs + throttle_secs 秒后。
  • throttle_secs:多少秒后又开始评估,如果没有新的 checkpoints 产生,则不评估,所以这个间隔是最小值。

3. 非分布式实例

# Set up feature columns. categorial_feature_a = categorial_column_with_hash_bucket(...) categorial_feature_a_emb = embedding_column( categorical_column=categorial_feature_a, ...) ... # other feature columns estimator = DNNClassifier( feature_columns=[categorial_feature_a_emb, ...], hidden_units=[1024, 512, 256]) # Or set up the model directory # estimator = DNNClassifier( # config=tf.estimator.RunConfig( # model_dir='/my_model', save_summary_steps=100), # feature_columns=[categorial_feature_a_emb, ...], # hidden_units=[1024, 512, 256]) # Input pipeline for train and evaluate. def train_input_fn(): # returns x, y # please shuffle the data. pass def eval_input_fn(): # returns x, y pass train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 

注意:在当前的实现中,estimator.evaluate 将被调用多次。这意味着在每次评估时,会重新创建评估图(包括eval_input_fn)。estimator.train 只会被调用一次

4. 分布式实例

上面的代码可以在不加修改的情况下用于分布式训练,但请确保所有 workerRunConfig.model_dir 设置为相同的目录(例如,一个所有 worker 都可以读写的共享文件系统。唯一需要做的就是正确得设置所有 worker 的环境变量 TF_CONFIG

设置环境变量的方式会随系统而变化。例如,在 Linux 上,设置环境变量的方式如下($ 是命令提示符):

$ TF_CONFIG=' 
   
     ' 
    python train_model.py 

训练的集群配置如下:

cluster = { 
   "chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], "ps": ["host4:2222", "host5:2222"]} 

chief training worker(必须有,且只能有一个)的 TF_CONFIG 应该被设置为:

# This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. TF_CONFIG='{ 
    "cluster": { 
    "chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], "ps": ["host4:2222", "host5:2222"] }, "task": { 
   "type": "chief", "index": 0} }' 

注意:chief worker 与其他 non-chief training worker 一样,也进行模型的训练 job。chief worker 除了进行模型训练,还管理一些其它 work(例如:checkpoint 保存、恢复,写入 summaries 等)。

non-chief training worker(可选,可以有多个)的 TF_CONFIG 应该被设置为:

# This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. TF_CONFIG='{ 
    "cluster": { 
    "chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], "ps": ["host4:2222", "host5:2222"] }, "task": { 
   "type": "worker", "index": 0} }' 

上面的 task.index 表示 worker 的编号。本例中,有三个 non-chief training worker,所以编号为 0,1,2。

parameter server(可以是多个)的 TF_CONFIG 应该被设置为:

# This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. TF_CONFIG='{ 
    "cluster": { 
    "chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], "ps": ["host4:2222", "host5:2222"] }, "task": { 
   "type": "ps", "index": 0} }' 

由于例子中参数服务器的个数为两个,所以 task.index 编号分别为 0,1。

评估的集群配置如下:
评估 task 的 TF_CONFIG 如下所示。评估是一个特殊的 task,该 task 不是训练集群的一部分。有可能只有一个。该 task 被用于模型评估。

# This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. TF_CONFIG='{ 
    "cluster": { 
    "chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], "ps": ["host4:2222", "host5:2222"] }, "task": { 
   "type": "evaluator", "index": 0} }' 

distributeexperimental_distribute.train_distributeexperimental_distribute.remote_cluster 被设置时,这个方法将开始在本机运行一个 client,该 client 将连接到 remote_cluster,以进行训练和评估。

参考文档:

  1. tf.estimator.train_and_evaluate 官方文档(英文)
  2. tf.estimator.train_and_evaluate 试用
  3. 推荐一个 Estimator+Experiment 的实例:tensorflow/models里的cifar10_estimator

注意:欢迎大家转载,但需注明出处哦
\quad \quad    \; https://blog.csdn.net/u0/article/details/

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/214843.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午3:19
下一篇 2026年3月18日 下午3:19


相关推荐

  • Claude API 全方位调用指南:轻松调用 Claude 4 Sonnet 与Claude 4 Opus 最新模型

    Claude API 全方位调用指南:轻松调用 Claude 4 Sonnet 与Claude 4 Opus 最新模型

    2026年3月16日
    2
  • scheduleAtFixedRate 和schedule

    scheduleAtFixedRate 和schedule最近整了一个 TimerTask 要求每天定点执行某一任务 code java importjava util Calendar importjava util Date importjava util Timer importjava util TimerTask importjavax servlet ServletConte imp

    2025年8月8日
    4
  • mybatis清空一级缓存_jvm缓存

    mybatis清空一级缓存_jvm缓存#一、前情提要长久以来,对springboot项目中缓存的使用都会有一些争论,一部分人认为缓存就应该具有延时性,即给他设置了10分钟的缓存,就应该10分钟后清理。还有一部分人认为缓存应该具有及时性(或弱及时性),即我设置了缓存后,一旦数据发生变化,缓存需要重新刷新。对于第一种观点,事实上现有的缓存结构就已经满足了,无需我们进行特殊操作,这里我们不做过多讨论。对于第二种观点,事实上现有的缓存结构也能够满足,只不过在加缓存的时候好加,可是在清理缓存的时候,我们需要手动对更新接口进行配置,可是由于项目的.

    2025年12月7日
    5
  • pycharm2021.8.3激活码_在线激活

    (pycharm2021.8.3激活码)这是一篇idea技术相关文章,由全栈君为大家提供,主要知识点是关于2021JetBrains全家桶永久激活码的内容IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.net/100143.html65MJGLILER-eyJsa…

    2022年3月22日
    55
  • File.createTempFile异常「建议收藏」

    错误:File.createtempfilejava.io.winntfilesystem.createfileexclusively(nativemethod)原来是Eclipse默认的JRE不是JDK下的修改为JDK下的jre就可以了转载于:https://www.cnblogs.com/cszzy/archive/2012/12/28/2837790.html…

    2022年4月11日
    107
  • SQL中的替换函数replace()使用

    SQL中的替换函数replace()使用语法 REPLACE string expression string pattern string replacement 参数 string expression 要搜索的字符串表达式 string expression 可以是字符或二进制数据类型 string pattern 是要查找的子字符串 string pattern 可以是字符或二进制数据类型 string pattern

    2026年3月19日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号