diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-09 11:54:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-09 12:06:09 -0800 |
commit | 7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (patch) | |
tree | ac821517862671c0d77bcbd67d1c911860324f67 /tensorflow/examples/learn | |
parent | c71ac2dce6bc73536956ff50e261c73f993af16c (diff) |
Move TensorForestEstimator to contrib, since that's where most of its code is and it will not be considered a canned estimator in the near future.
Change: 143989623
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/random_forest_mnist.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index a34d52275a..6a943eb42e 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -21,25 +21,24 @@ import argparse import sys import tempfile -import tensorflow as tf - # pylint: disable=g-backslash-continuation from tensorflow.contrib.learn.python.learn\ import metric_spec -from tensorflow.contrib.learn.python.learn.estimators\ - import random_forest from tensorflow.contrib.tensor_forest.client\ import eval_metrics +from tensorflow.contrib.tensor_forest.client\ + import random_forest from tensorflow.contrib.tensor_forest.python\ import tensor_forest from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.platform import app FLAGS = None def build_estimator(model_dir): """Build an estimator.""" - params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + params = tensor_forest.ForestHParams( num_classes=10, num_features=784, num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes) graph_builder_class = tensor_forest.RandomForestGraphs @@ -129,4 +128,4 @@ if __name__ == '__main__': help='If true, use training loss as termination criteria.' ) FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + app.run(main=main, argv=[sys.argv[0]] + unparsed) |