aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-01-13 12:14:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-13 12:24:11 -0800
commit1c5120141b2043c4e0721774c183cb01d23b0682 (patch)
tree0dee49d45aa2977da670d38c929626bcd6872056
parent8f893368fc572f63da83666ca22722cea265a6c8 (diff)
Fix SavedModel export when predictions is a single tensor and output_alternatives not given
Change: 144470271
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/prediction_key.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py15
3 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py
index 8a6d0ef018..7dc26781f9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py
@@ -25,3 +25,4 @@ class PredictionKey(object):
LOGISTIC = "logistic"
SCORES = "scores"
TOP_K = "top_k"
+ GENERIC = "output"
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index c386b2adf9..c3fdd3086c 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -167,6 +167,8 @@ def get_output_alternatives(
# interpret the model as single-headed of unknown type.
default_problem_type = constants.ProblemType.UNSPECIFIED
default_outputs = model_fn_ops.predictions
+ if not isinstance(default_outputs, dict):
+ default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
output_alternatives = {actual_default_output_alternative_key:
(default_problem_type, default_outputs)}
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
index 2ce28f8648..23d171b58b 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
@@ -124,6 +124,21 @@ class SavedModelExportUtilsTest(test.TestCase):
})
}, output_alternatives)
+ def test_get_output_alternatives_implicit_single(self):
+ prediction_tensor = constant_op.constant(["bogus"])
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions=prediction_tensor,
+ output_alternatives=None)
+
+ output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops)
+ self.assertEqual({
+ "default_output_alternative": (constants.ProblemType.UNSPECIFIED, {
+ "output": prediction_tensor
+ })
+ }, output_alternatives)
+
def test_build_all_signature_defs(self):
input_features = constant_op.constant(["10"])
input_example = constant_op.constant(["11"])