diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-10-19 16:16:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-19 16:20:20 -0700 |
commit | 7a253f3da99c3692d464a8dd95d8280d4cd8973a (patch) | |
tree | 8e5aba71e9d6457ab8543841002cb1916e182dc0 /tensorflow/examples/learn | |
parent | bc93dcbd9f7b445c5f6f0d1c8f597324d412a76a (diff) |
Fix random_forest_mnist.py and eliminate a contrib.learn reference to skcompat.
PiperOrigin-RevId: 172815173
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/random_forest_mnist.py | 65 |
1 files changed, 36 insertions, 29 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index 3c09990ea1..72c935cdae 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -1,4 +1,4 @@ - # Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,18 +21,14 @@ import argparse import sys import tempfile -# pylint: disable=g-backslash-continuation -from tensorflow.contrib.learn.python.learn\ - import metric_spec -from tensorflow.contrib.learn.python.learn.estimators\ - import estimator -from tensorflow.contrib.tensor_forest.client\ - import eval_metrics -from tensorflow.contrib.tensor_forest.client\ - import random_forest -from tensorflow.contrib.tensor_forest.python\ - import tensor_forest +import numpy + +from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.tensor_forest.client import eval_metrics +from tensorflow.contrib.tensor_forest.client import random_forest +from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.platform import app FLAGS = None @@ -41,16 +37,15 @@ FLAGS = None def build_estimator(model_dir): """Build an estimator.""" params = tensor_forest.ForestHParams( - num_classes=10, num_features=784, - num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes) + num_classes=10, + num_features=784, + num_trees=FLAGS.num_trees, + max_nodes=FLAGS.max_nodes) graph_builder_class = tensor_forest.RandomForestGraphs if FLAGS.use_training_loss: graph_builder_class = tensor_forest.TrainingLossForest - # Use the SKCompat wrapper, which gives us a convenient way to split - # in-memory data like MNIST into batches. - return estimator.SKCompat(random_forest.TensorForestEstimator( - params, graph_builder_class=graph_builder_class, - model_dir=model_dir)) + return random_forest.TensorForestEstimator( + params, graph_builder_class=graph_builder_class, model_dir=model_dir) def train_and_eval(): @@ -62,18 +57,30 @@ def train_and_eval(): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False) - est.fit(x=mnist.train.images, y=mnist.train.labels, - batch_size=FLAGS.batch_size) + train_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.train.images}, + y=mnist.train.labels.astype(numpy.int32), + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + est.fit(input_fn=train_input_fn, steps=None) metric_name = 'accuracy' - metric = {metric_name: - metric_spec.MetricSpec( - eval_metrics.get_metric(metric_name), - prediction_key=eval_metrics.get_prediction_key(metric_name))} - - results = est.score(x=mnist.test.images, y=mnist.test.labels, - batch_size=FLAGS.batch_size, - metrics=metric) + metric = { + metric_name: + metric_spec.MetricSpec( + eval_metrics.get_metric(metric_name), + prediction_key=eval_metrics.get_prediction_key(metric_name)) + } + + test_input_fn = numpy_io.numpy_input_fn( + x={'images': mnist.test.images}, + y=mnist.test.labels.astype(numpy.int32), + num_epochs=1, + batch_size=FLAGS.batch_size, + shuffle=False) + + results = est.evaluate(input_fn=test_input_fn, metrics=metric) for key in sorted(results): print('%s: %s' % (key, results[key])) |