path: root/tensorflow/examples/learn
author A. Unique TensorFlower <gardener@tensorflow.org> 2017-02-23 10:45:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-02-23 11:23:31 -0800
commit 7b2b8cc0b221cb5ec33e7e8d8a856d7453eaba02 (patch)
tree 3a844736bdae1c8bbebc70512cb6dbc5251b53fc /tensorflow/examples/learn
parent e609117c8ac28c0fb60e76dc745a7333d1a558ca (diff)
Refactor TensorForestEstimator to inherit from tf.learn Estimator. This changes the interface a bit:
- TensorForestLossHook should be passed to fit; it is no longer included by default.
- predict returns an iterable of dicts that contain probabilities and predictions.
- in-memory data sets (x=, y=) should now wrap TensorForestEstimator with estimator.SKCompat, instead of using the unified interface that it previously supplied.
Change: 148362051
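For readers updating call sites, a minimal, untested sketch of the new calling convention follows; the hyperparameter and flag values simply mirror the example script below and are illustrative, not part of this change.

# Illustrative sketch of the refactored interface, not part of this commit.
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data

params = tensor_forest.ForestHParams(
    num_classes=10, num_features=784, num_trees=100, max_nodes=1000)

# In-memory (x=, y=) data now goes through the SKCompat wrapper.
est = estimator.SKCompat(
    random_forest.TensorForestEstimator(params, model_dir='/tmp/forest'))

# The loss hook is no longer attached by default; pass it to fit explicitly
# if training should stop once the forest is no longer growing.
hook = random_forest.TensorForestLossHook(100)

mnist = input_data.read_data_sets('/tmp/mnist_data', one_hot=False)
est.fit(x=mnist.train.images, y=mnist.train.labels,
        batch_size=1000, monitors=[hook])

Evaluation on the wrapped estimator goes through est.score(...) rather than estimator.evaluate(...), as the last hunk of the diff below shows.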
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- tensorflow/examples/learn/random_forest_mnist.py | 27
1 file changed, 13 insertions, 14 deletions
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py
index 6a943eb42e..3c09990ea1 100644
--- a/tensorflow/examples/learn/random_forest_mnist.py
+++ b/tensorflow/examples/learn/random_forest_mnist.py
@@ -1,4 +1,4 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,6 +24,8 @@ import tempfile
# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
+from tensorflow.contrib.learn.python.learn.estimators\
+ import estimator
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
from tensorflow.contrib.tensor_forest.client\
@@ -44,9 +46,11 @@ def build_estimator(model_dir):
graph_builder_class = tensor_forest.RandomForestGraphs
if FLAGS.use_training_loss:
graph_builder_class = tensor_forest.TrainingLossForest
- return random_forest.TensorForestEstimator(
+ # Use the SKCompat wrapper, which gives us a convenient way to split
+ # in-memory data like MNIST into batches.
+ return estimator.SKCompat(random_forest.TensorForestEstimator(
params, graph_builder_class=graph_builder_class,
- model_dir=model_dir)
+ model_dir=model_dir))
def train_and_eval():
@@ -54,17 +58,12 @@ def train_and_eval():
model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
print('model directory = %s' % model_dir)
- estimator = build_estimator(model_dir)
-
- # TensorForest's loss hook allows training to terminate early if the
- # forest is no longer growing.
- early_stopping_rounds = 100
- monitor = random_forest.TensorForestLossHook(early_stopping_rounds)
+ est = build_estimator(model_dir)
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
- estimator.fit(x=mnist.train.images, y=mnist.train.labels,
- batch_size=FLAGS.batch_size, monitors=[monitor])
+ est.fit(x=mnist.train.images, y=mnist.train.labels,
+ batch_size=FLAGS.batch_size)
metric_name = 'accuracy'
metric = {metric_name:
@@ -72,9 +71,9 @@ def train_and_eval():
eval_metrics.get_metric(metric_name),
prediction_key=eval_metrics.get_prediction_key(metric_name))}
- results = estimator.evaluate(x=mnist.test.images, y=mnist.test.labels,
- batch_size=FLAGS.batch_size,
- metrics=metric)
+ results = est.score(x=mnist.test.images, y=mnist.test.labels,
+ batch_size=FLAGS.batch_size,
+ metrics=metric)
for key in sorted(results):
print('%s: %s' % (key, results[key]))
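As a side note, this example does not exercise the new predict behavior described in the commit message; a rough, untested sketch of consuming it on an unwrapped TensorForestEstimator follows (reusing params, model_dir, and mnist from above). The dict keys and the in-memory x= argument are assumptions, not taken from this change.

# Rough sketch only: key names are assumed from the commit message above,
# not verified against the tensor_forest implementation.
raw_est = random_forest.TensorForestEstimator(params, model_dir=model_dir)
for pred in raw_est.predict(x=mnist.test.images[:5], as_iterable=True):
  # predict now yields one dict per example with probabilities and predictions.
  print(pred.get('probabilities'), pred.get('predictions'))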