aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/random_forest_mnist.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py
index a34d52275a..6a943eb42e 100644
--- a/tensorflow/examples/learn/random_forest_mnist.py
+++ b/tensorflow/examples/learn/random_forest_mnist.py
@@ -21,25 +21,24 @@ import argparse
import sys
import tempfile
-import tensorflow as tf
-
# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
-from tensorflow.contrib.learn.python.learn.estimators\
- import random_forest
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
+from tensorflow.contrib.tensor_forest.client\
+ import random_forest
from tensorflow.contrib.tensor_forest.python\
import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.python.platform import app
FLAGS = None
def build_estimator(model_dir):
"""Build an estimator."""
- params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ params = tensor_forest.ForestHParams(
num_classes=10, num_features=784,
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
@@ -129,4 +128,4 @@ if __name__ == '__main__':
help='If true, use training loss as termination criteria.'
)
FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)