diff options
author | David Soergel <soergel@google.com> | 2017-05-11 19:19:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-11 19:23:47 -0700 |
commit | cfdf449ae2daa749e66565f846808bb9e5c1bff5 (patch) | |
tree | 0e40abd5ca39df2f72e929aeb95059ce5cc630a6 | |
parent | 8364dd7d8c9a1ff1ea311b3e70c120ca34215f94 (diff) |
Update TensorForest to provide output_alternatives for serving.
Also, fix reversed inference dict key constants.
PiperOrigin-RevId: 155826149
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 21 |
2 files changed, 21 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index c99f9b7c12..1726986354 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -27,8 +27,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -INFERENCE_PROB_NAME = prediction_key.PredictionKey.CLASSES -INFERENCE_PRED_NAME = prediction_key.PredictionKey.PROBABILITIES +INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES +INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES FEATURE_IMPORTANCE_NAME = 'global_feature_importance' diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 0ba636c697..0da1f78755 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -19,8 +19,10 @@ from __future__ import print_function from tensorflow.contrib import framework as contrib_framework +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest @@ -145,15 +147,29 @@ def get_model_fn(params, graph_builder = graph_builder_class(params, device_assigner=dev_assn) inference = {} + output_alternatives = None if (mode == model_fn_lib.ModeKeys.EVAL or mode == model_fn_lib.ModeKeys.INFER): inference[eval_metrics.INFERENCE_PROB_NAME] = ( graph_builder.inference_graph(features)) - if not params.regression: + if params.regression: + predictions = { + None: inference[eval_metrics.INFERENCE_PROB_NAME]} + output_alternatives = { + None: (constants.ProblemType.LINEAR_REGRESSION, predictions)} + else: inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax( inference[eval_metrics.INFERENCE_PROB_NAME], 1) + predictions = { + prediction_key.PredictionKey.PROBABILITIES: + inference[eval_metrics.INFERENCE_PROB_NAME], + prediction_key.PredictionKey.CLASSES: + inference[eval_metrics.INFERENCE_PRED_NAME]} + output_alternatives = { + None: (constants.ProblemType.CLASSIFICATION, predictions)} + if report_feature_importances: inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = ( graph_builder.feature_importances()) @@ -205,7 +221,8 @@ def get_model_fn(params, loss=training_loss, train_op=training_graph, training_hooks=training_hooks, - scaffold=scaffold) + scaffold=scaffold, + output_alternatives=output_alternatives) return _model_fn |