diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-23 10:45:58 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-23 11:23:31 -0800 |
commit | 7b2b8cc0b221cb5ec33e7e8d8a856d7453eaba02 (patch) | |
tree | 3a844736bdae1c8bbebc70512cb6dbc5251b53fc /tensorflow/examples/learn | |
parent | e609117c8ac28c0fb60e76dc745a7333d1a558ca (diff) |
Refactor TensorForestEstimator to inherit from tf.learn Estimator. This changes the interface a bit:
- TensorForestLossHook should be passed to fit, it isn't included by default
- predict returns an iterable of dicts that contains probabilities and predictions.
- in-memory data sets (x=, y=) should now wrap TensorForestEstimator with estimator.SKCompat, instead of using the unified interface that it previously supplied.
Change: 148362051
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/random_forest_mnist.py | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index 6a943eb42e..3c09990ea1 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. + # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ import tempfile # pylint: disable=g-backslash-continuation from tensorflow.contrib.learn.python.learn\ import metric_spec +from tensorflow.contrib.learn.python.learn.estimators\ + import estimator from tensorflow.contrib.tensor_forest.client\ import eval_metrics from tensorflow.contrib.tensor_forest.client\ @@ -44,9 +46,11 @@ def build_estimator(model_dir): graph_builder_class = tensor_forest.RandomForestGraphs if FLAGS.use_training_loss: graph_builder_class = tensor_forest.TrainingLossForest - return random_forest.TensorForestEstimator( + # Use the SKCompat wrapper, which gives us a convenient way to split + # in-memory data like MNIST into batches. + return estimator.SKCompat(random_forest.TensorForestEstimator( params, graph_builder_class=graph_builder_class, - model_dir=model_dir) + model_dir=model_dir)) def train_and_eval(): @@ -54,17 +58,12 @@ def train_and_eval(): model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir print('model directory = %s' % model_dir) - estimator = build_estimator(model_dir) - - # TensorForest's loss hook allows training to terminate early if the - # forest is no longer growing. - early_stopping_rounds = 100 - monitor = random_forest.TensorForestLossHook(early_stopping_rounds) + est = build_estimator(model_dir) mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False) - estimator.fit(x=mnist.train.images, y=mnist.train.labels, - batch_size=FLAGS.batch_size, monitors=[monitor]) + est.fit(x=mnist.train.images, y=mnist.train.labels, + batch_size=FLAGS.batch_size) metric_name = 'accuracy' metric = {metric_name: @@ -72,9 +71,9 @@ def train_and_eval(): eval_metrics.get_metric(metric_name), prediction_key=eval_metrics.get_prediction_key(metric_name))} - results = estimator.evaluate(x=mnist.test.images, y=mnist.test.labels, - batch_size=FLAGS.batch_size, - metrics=metric) + results = est.score(x=mnist.test.images, y=mnist.test.labels, + batch_size=FLAGS.batch_size, + metrics=metric) for key in sorted(results): print('%s: %s' % (key, results[key])) |