aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-10-19 16:16:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-19 16:20:20 -0700
commit7a253f3da99c3692d464a8dd95d8280d4cd8973a (patch)
tree8e5aba71e9d6457ab8543841002cb1916e182dc0 /tensorflow/examples/learn
parentbc93dcbd9f7b445c5f6f0d1c8f597324d412a76a (diff)
Fix random_forest_mnist.py and eliminate a contrib.learn reference to skcompat.
PiperOrigin-RevId: 172815173
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/random_forest_mnist.py65
1 files changed, 36 insertions, 29 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py
index 3c09990ea1..72c935cdae 100644
--- a/tensorflow/examples/learn/random_forest_mnist.py
+++ b/tensorflow/examples/learn/random_forest_mnist.py
@@ -1,4 +1,4 @@
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,18 +21,14 @@ import argparse
import sys
import tempfile
-# pylint: disable=g-backslash-continuation
-from tensorflow.contrib.learn.python.learn\
- import metric_spec
-from tensorflow.contrib.learn.python.learn.estimators\
- import estimator
-from tensorflow.contrib.tensor_forest.client\
- import eval_metrics
-from tensorflow.contrib.tensor_forest.client\
- import random_forest
-from tensorflow.contrib.tensor_forest.python\
- import tensor_forest
+import numpy
+
+from tensorflow.contrib.learn.python.learn import metric_spec
+from tensorflow.contrib.tensor_forest.client import eval_metrics
+from tensorflow.contrib.tensor_forest.client import random_forest
+from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.platform import app
FLAGS = None
@@ -41,16 +37,15 @@ FLAGS = None
def build_estimator(model_dir):
"""Build an estimator."""
params = tensor_forest.ForestHParams(
- num_classes=10, num_features=784,
- num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
+ num_classes=10,
+ num_features=784,
+ num_trees=FLAGS.num_trees,
+ max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
if FLAGS.use_training_loss:
graph_builder_class = tensor_forest.TrainingLossForest
- # Use the SKCompat wrapper, which gives us a convenient way to split
- # in-memory data like MNIST into batches.
- return estimator.SKCompat(random_forest.TensorForestEstimator(
- params, graph_builder_class=graph_builder_class,
- model_dir=model_dir))
+ return random_forest.TensorForestEstimator(
+ params, graph_builder_class=graph_builder_class, model_dir=model_dir)
def train_and_eval():
@@ -62,18 +57,30 @@ def train_and_eval():
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
- est.fit(x=mnist.train.images, y=mnist.train.labels,
- batch_size=FLAGS.batch_size)
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'images': mnist.train.images},
+ y=mnist.train.labels.astype(numpy.int32),
+ batch_size=FLAGS.batch_size,
+ num_epochs=None,
+ shuffle=True)
+ est.fit(input_fn=train_input_fn, steps=None)
metric_name = 'accuracy'
- metric = {metric_name:
- metric_spec.MetricSpec(
- eval_metrics.get_metric(metric_name),
- prediction_key=eval_metrics.get_prediction_key(metric_name))}
-
- results = est.score(x=mnist.test.images, y=mnist.test.labels,
- batch_size=FLAGS.batch_size,
- metrics=metric)
+ metric = {
+ metric_name:
+ metric_spec.MetricSpec(
+ eval_metrics.get_metric(metric_name),
+ prediction_key=eval_metrics.get_prediction_key(metric_name))
+ }
+
+ test_input_fn = numpy_io.numpy_input_fn(
+ x={'images': mnist.test.images},
+ y=mnist.test.labels.astype(numpy.int32),
+ num_epochs=1,
+ batch_size=FLAGS.batch_size,
+ shuffle=False)
+
+ results = est.evaluate(input_fn=test_input_fn, metrics=metric)
for key in sorted(results):
print('%s: %s' % (key, results[key]))