aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics.py4
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py21
2 files changed, 21 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index c99f9b7c12..1726986354 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -27,8 +27,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-INFERENCE_PROB_NAME = prediction_key.PredictionKey.CLASSES
-INFERENCE_PRED_NAME = prediction_key.PredictionKey.PROBABILITIES
+INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES
+INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES
FEATURE_IMPORTANCE_NAME = 'global_feature_importance'
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0ba636c697..0da1f78755 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -19,8 +19,10 @@ from __future__ import print_function
from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
@@ -145,15 +147,29 @@ def get_model_fn(params,
graph_builder = graph_builder_class(params,
device_assigner=dev_assn)
inference = {}
+ output_alternatives = None
if (mode == model_fn_lib.ModeKeys.EVAL or
mode == model_fn_lib.ModeKeys.INFER):
inference[eval_metrics.INFERENCE_PROB_NAME] = (
graph_builder.inference_graph(features))
- if not params.regression:
+ if params.regression:
+ predictions = {
+ None: inference[eval_metrics.INFERENCE_PROB_NAME]}
+ output_alternatives = {
+ None: (constants.ProblemType.LINEAR_REGRESSION, predictions)}
+ else:
inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
inference[eval_metrics.INFERENCE_PROB_NAME], 1)
+ predictions = {
+ prediction_key.PredictionKey.PROBABILITIES:
+ inference[eval_metrics.INFERENCE_PROB_NAME],
+ prediction_key.PredictionKey.CLASSES:
+ inference[eval_metrics.INFERENCE_PRED_NAME]}
+ output_alternatives = {
+ None: (constants.ProblemType.CLASSIFICATION, predictions)}
+
if report_feature_importances:
inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
graph_builder.feature_importances())
@@ -205,7 +221,8 @@ def get_model_fn(params,
loss=training_loss,
train_op=training_graph,
training_hooks=training_hooks,
- scaffold=scaffold)
+ scaffold=scaffold,
+ output_alternatives=output_alternatives)
return _model_fn