tensorflow estimator 实践

tensorflow estimator 实践nbsp nbsp 本文以 mnist 数据集为例 estimator 通常是和 tf 的 dataset 一起使用 故先制作 tfrecord 文件 在使用 estimator 进行测试 文章结构 1 文件目录 2 制作 tfrecord 文件 3 使用 estimator 训练模型 4 tf estimator Estimator 参数介绍 文件目录 nbsp nbsp nbsp nbsp nbsp nbsp data 目录下存放 mnist 数据集

    本文以mnist数据集为例。estimator通常是和tf的dataset一起使用,故先制作tfrecord文件,在使用estimator进行测试。

文章结构:

1.文件目录

2. 制作tfrecord文件

3.使用estimator训练模型

4.tf.estimator.Estimator()参数介绍:


文件目录: 

     tensorflow estimator 实践

   data目录下存放mnist数据集,并且tfrecord文件也将存放在data目录下。

   result目录下保存训练的模型

   make_record.py: 制作tfreocrd文件

   mnist.py: 使用estimator训练模型、test结果

 

制作tfrecord文件:

from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf mnist = input_data.read_data_sets("./data/", one_hot=True) def tf_record(data, labels, path): #data是mnist图片 #labels是图片的标签 # path是tfrecord保存的路劲 writer=tf.python_io.TFRecordWriter(path) for example, label in zip(data, labels): tf_example=tf.train.Example( features=tf.train.Features( feature={ "image":tf.train.Feature(float_list=tf.train.FloatList(value=list(example))), "label":tf.train.Feature(float_list=tf.train.FloatList(value=list(label))) } ) ) writer.write(tf_example.SerializeToString()) writer.close()

使用estimator训练模型:

 

from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf from make_record import tf_record flags = tf.flags FLAGS = flags.FLAGS tf.logging.set_verbosity(tf.logging.INFO) # 模型部分,使用2个卷积+全连接进行建模 def create_model(images): def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_poo_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') image = tf.reshape(images, [-1, 28, 28, 1]) W_conv1 = weight_variable([5, 5, 1, 32]) b_conv1 = bias_variable([32]) h_conv1 = tf.nn.relu(conv2d(image, W_conv1) + b_conv1) h_pool1 = max_poo_2x2(h_conv1) W_conv2 = weight_variable([5, 5, 32, 64]) b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) h_pool2 = max_poo_2x2(h_conv2) W_fc1 = weight_variable([7 * 7 * 64, 1024]) b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) W_fc2 = weight_variable([1024, 10]) b_fc2 = bias_variable([10]) prediction = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2) return prediction # 将词函数传递给estimator的model_fn参数 # 关于estimator中的各个参数作用,见后 def model_fn(features,mode): is_training = (mode == tf.estimator.ModeKeys.TRAIN) prediction = create_model(images=images) if is_training: labels = features['label'] loss = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(prediction))) train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, global_step=tf.train.get_or_create_global_step()) return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op) if mode == tf.estimator.ModeKeys.PREDICT: labels = features['label'] return tf.estimator.EstimatorSpec( mode=tf.estimator.ModeKeys.PREDICT, predictions={ "ground_truth":tf.argmax(labels, axis=1), "prediction":tf.argmax(prediction, axis=1) }) # 此函数返回一个函数, 其返回train的dataset def input_fn_builder(input_file): def input_fn(): def pase(record): keys_to_features = { "image": tf.FixedLenFeature([784], tf.float32), "label": tf.FixedLenFeature([10], tf.float32) } parsed = tf.parse_single_example(record, keys_to_features) return parsed d = tf.data.TFRecordDataset(input_file) return d.map(pase).shuffle(buffer_size=100).batch(32).repeat(10) return input_fn # 此函数返回一个函数, 其返回test的dataset def input_fn_builder_test(input_file): def input_fn(): def pase(record): keys_to_features = { "image": tf.FixedLenFeature([784], tf.float32), "label": tf.FixedLenFeature([10], tf.float32) } parsed = tf.parse_single_example(record, keys_to_features) return parsed d = tf.data.TFRecordDataset(input_file) return d.map(pase).shuffle(buffer_size=100).batch(32).repeat(1) return input_fn def main(_): tf.logging.set_verbosity(tf.logging.INFO) tf.logging.info("read mnist dataset") mnist = input_data.read_data_sets('./data', one_hot=True) tf.logging.info("make train record") train = mnist.train.images labels = mnist.train.labels tf_record(train, labels, "./data/train_record") tf.logging.info("make test record") test = mnist.test.images labels = mnist.test.labels tf_record(test, labels, "./data/test_record") session_config = tf.ConfigProto(log_device_placement=True) session_config.gpu_options.per_process_gpu_memory_fraction = 0.5 # 运行配置,如空置显存,训练多少步(这里是2000)保存一次。 run_config = tf.estimator.RunConfig(session_config=session_config,,save_checkpoints_steps=2000) # model_dir是模型的保存路径 # model_fn是模型函数 # config是运行配置 estimator = tf.estimator.Estimator( model_dir="./result", model_fn=model_fn, config=run_config, ) input_fn = input_fn_builder("./data/train_record") input_fn_test = input_fn_builder_test("./data/test_record") # 训练 estimator.train(input_fn=input_fn) all_results = [] # 测试 for result in estimator.predict(input_fn_test, yield_single_examples=True): if len(all_results) % 1000 == 0: tf.logging.info("Processing example: %d" % (len(all_results))) label = int(result["ground_truth"]) prediction = int(result['prediction']) all_results.append([label, prediction]) acc = 0. for idx, result in enumerate(all_results): if idx == 0: print(type(result)) print(type(all_results)) print(len(all_results)) print(len(result)) if result[0] == result[1]: acc = acc + 1 print(acc/len(all_results)) if __name__ == "__main__": tf.app.run() 

结果:

tensorflow estimator 实践

模型:

tensorflow estimator 实践

tf.estimator.Estimator()参数介绍: 

这个只能自己看了

class Estimator(object): """Estimator class to train and evaluate TensorFlow models. The `Estimator` object wraps a model which is specified by a `model_fn`, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions. All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a subdirectory thereof. If `model_dir` is not set, a temporary directory is used. The `config` argument can be passed `RunConfig` object containing information about the execution environment. It is passed on to the `model_fn`, if the `model_fn` has a parameter named "config" (and input functions in the same manner). If the `config` parameter is not passed, it is instantiated by the `Estimator`. Not passing config means that defaults useful for local execution are used. `Estimator` makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing. The `params` argument contains hyperparameters. It is passed to the `model_fn`, if the `model_fn` has a parameter named "params", and to the input functions in the same manner. `Estimator` only passes params along, it does not inspect it. The structure of `params` is therefore entirely up to the developer. None of `Estimator`'s methods can be overridden in subclasses (its constructor enforces this). Subclasses should use `model_fn` to configure the base class, and may add methods implementing specialized functionality. @compatibility(eager) Estimators are not compatible with eager execution. @end_compatibility """ def __init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None): """Constructs an `Estimator` instance. See @{$estimators} for more information. To warm-start an `Estimator`: python estimator = tf.estimator.DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], hidden_units=[1024, 512, 256], warm_start_from="/path/to/checkpoint/dir") For more details on warm-start configuration, see @{tf.estimator.WarmStartSettings$WarmStartSettings}. Args: model_fn: Model function. Follows the signature: * Args: * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `Tensor` or `dict` of same. * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `Tensor` or `dict` of same (for multi-head models). If mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`. * `mode`: Optional. Specifies if this training, evaluation or prediction. See `ModeKeys`. * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tuning. * `config`: Optional configuration object. Will receive what is passed to Estimator in `config` parameter, or the default `config`. Allows updating things in your `model_fn` based on configuration such as `num_ps_replicas`, or `model_dir`. * Returns: `EstimatorSpec` model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `PathLike` object, the path will be resolved. If `None`, the model_dir in `config` will be used if set. If both are set, they must be same. If both are `None`, a temporary directory will be used. config: Configuration object. params: `dict` of hyper parameters that will be passed into `model_fn`. Keys are names of parameters, values are basic python types. warm_start_from: Optional string filepath to a checkpoint to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string filepath is provided instead of a `WarmStartSettings`, then all variables are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. Raises: RuntimeError: If eager execution is enabled. ValueError: parameters of `model_fn` don't match `params`. ValueError: if this is called via a subclass and if that class overrides a member of `Estimator`. """

若有问题欢迎评论指出!!转载请标明地址:https://mp.csdn.net/postedit/

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/202881.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月19日 下午11:05
下一篇 2026年3月19日 下午11:06


相关推荐

  • php把int转string,如何在php中实现int转string

    php把int转string,如何在php中实现int转string如何在 php 中实现 int 转 string 发布时间 2020 07 2009 22 45 来源 亿速云阅读 83 作者 Leah 如何在 php 中实现 int 转 string 针对这个问题 这篇文章详细介绍了相对应的分析和解答 希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法 php 中 int 转 string 的方法 首先将 int 转换为 Integer 类型 然后再调用 toString 方法即可 代码为 Intr

    2026年3月26日
    2
  • Android 指纹认证

    Android 指纹认证安卓指纹认证使用智能手机触摸传感器对用户进行身份验证 AndroidMarsh 棉花糖 提供了一套 API 使用户很容易使用触摸传感器 在 AndroidMarsh 之前访问触摸传感器的方法不是标准的 本文地址 http wuyudong com 2016 12 15 3146 html 转载请注明出处 使用安卓指纹认证有几个好处 1 更快更容易使用 2 安

    2026年3月18日
    1
  • 《Java核心技术 卷1》「建议收藏」

    《Java核心技术 卷1》「建议收藏」<1>静态字段和静态方法classEmployee{privatestaticintnextId=1;privateintid;….}每一个Employee对象都有一个自己的id字段,但是这个类的所有实例将共享一个nextId字段,换句话说,如果有1000个Employee类对象,则有1000个实例字段id,分别对应一个对象,但是只有一个静态字段nextId,即使没有Employee对象,静态字段nextId也存在,它属于类,…

    2022年7月8日
    21
  • portainer添加mysql

    portainer添加mysql自己安装 mysql 的时候老是遇到问题 比如端口号没写 环境变量没写 就是想不到去 dockerhub 上去看一看 介绍两种方式安装 mysql 方式一 container Image 端口号环境变量最主要是用环境变量设置用户名和密码 MYSQL ROOT PASSWORD 比如 name MYSQL ROOT PASSWORDvalu 方式

    2026年3月17日
    3
  • 一个完整的、全面k8s化的集群稳定架构(值得借鉴)

    点击上方“全栈程序员社区”,星标公众号 重磅干货,第一时间送达 作者:紫色飞猪 cnblogs.com/zisefeizhu/p/13692782.html 前言 我司的集群时刻处…

    2021年6月28日
    131
  • hdu 1394

    hdu 1394

    2022年1月31日
    52

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号