aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-30 10:18:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-30 12:32:21 -0800
commit8f0e7207774279f4fe50f4d6c4fbd576e2941463 (patch)
treecf2b67f0885347ae46e1fed30ab34bd98c7f3310 /tensorflow/contrib/tensor_forest
parent7149a2e2e2f549035f23e21224ee41afe8df3876 (diff)
Prepare variance to be exported for serving with the servo library.
PiperOrigin-RevId: 183851026
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py41
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index a998ac1e11..4abcc20ed3 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
-
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
@@ -43,8 +43,8 @@ from tensorflow.python.training import training_util
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
-VARIANCE_PREDICTION_KEY = 'regression_variance'
-
+VARIANCE_PREDICTION_KEY = 'prediction_variance'
+ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001
@@ -134,7 +134,8 @@ def get_model_fn(params,
trainer_id=0,
report_feature_importances=False,
local_eval=False,
- head_scope=None):
+ head_scope=None,
+ include_all_in_serving=False):
"""Return a model function given a way to construct a graph builder."""
if model_head is None:
model_head = get_default_head(params, weights_name)
@@ -238,7 +239,13 @@ def get_model_fn(params,
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
-
+ if include_all_in_serving:
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ if not model_ops.output_alternatives:
+ model_ops.output_alternatives = {}
+ model_ops.output_alternatives[ALL_SERVING_KEY] = (
+ constants.ProblemType.UNSPECIFIED, model_ops.predictions)
return model_ops
return _model_fn
@@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator):
report_feature_importances=False,
local_eval=False,
version=None,
- head=None):
+ head=None,
+ include_all_in_serving=False):
"""Initializes a TensorForestEstimator instance.
Args:
@@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator):
version: Unused.
head: A heads_lib.Head object that calculates losses and such. If None,
one will be automatically created based on params.
+ include_all_in_serving: if True, allow preparation of the complete
+ prediction dict including the variance to be exported for serving with
+ the Servo lib; and it also requires calling export_savedmodel with
+ default_output_alternative_key=ALL_SERVING_KEY, i.e.
+ estimator.export_savedmodel(export_dir_base=your_export_dir,
+ serving_input_fn=your_export_input_fn,
+ default_output_alternative_key=ALL_SERVING_KEY)
+ if False, resort to default behavior, i.e. export scores and
+ probabilities but no variances. In this case
+ default_output_alternative_key should be None while calling
+ export_savedmodel().
+ Note, that due to backward compatibility we cannot always set
+ include_all_in_serving to True because in this case calling
+ export_saved_model() without
+ default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
+ saved_model_export_utils.get_output_alternatives() would raise
+ ValueError.
Returns:
A `TensorForestEstimator` instance.
@@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator):
num_trainers=num_trainers,
trainer_id=trainer_id,
report_feature_importances=report_feature_importances,
- local_eval=local_eval),
+ local_eval=local_eval,
+ include_all_in_serving=include_all_in_serving,
+ ),
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)