aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-09 11:54:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-09 12:06:09 -0800
commit7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (patch)
treeac821517862671c0d77bcbd67d1c911860324f67 /tensorflow/examples/learn
parentc71ac2dce6bc73536956ff50e261c73f993af16c (diff)
Move TensorForestEstimator to contrib, since that's where most of its code is and it will not be considered a canned estimator in the near future.
Change: 143989623
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/random_forest_mnist.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py
index a34d52275a..6a943eb42e 100644
--- a/tensorflow/examples/learn/random_forest_mnist.py
+++ b/tensorflow/examples/learn/random_forest_mnist.py
@@ -21,25 +21,24 @@ import argparse
import sys
import tempfile
-import tensorflow as tf
-
# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
-from tensorflow.contrib.learn.python.learn.estimators\
- import random_forest
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
+from tensorflow.contrib.tensor_forest.client\
+ import random_forest
from tensorflow.contrib.tensor_forest.python\
import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.python.platform import app
FLAGS = None
def build_estimator(model_dir):
"""Build an estimator."""
- params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ params = tensor_forest.ForestHParams(
num_classes=10, num_features=784,
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
@@ -129,4 +128,4 @@ if __name__ == '__main__':
help='If true, use training loss as termination criteria.'
)
FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)