aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-09 12:37:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 12:39:32 -0700
commit970baf64cffe9de0b124b5eea53b1ee1d5158506 (patch)
treeb3295a5288030bde2bf7147d997285d36954a75b
parent2e6f8b3f05fe2d212c19b9598f93f4e6ee07675f (diff)
Renames exported signature names in MultiHead so head_name comes first.
PiperOrigin-RevId: 192168628
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head_test.py16
2 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index bbbc19cc4d..ce75899214 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -345,7 +345,7 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
if k == _DEFAULT_SERVING_KEY:
key = head_name
else:
- key = '%s/%s' % (k, head_name)
+ key = '%s/%s' % (head_name, k)
export_outputs[key] = v
if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access
isinstance(v, export_output_lib.PredictOutput)):
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index d9e5aca295..3d6fccb118 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
@@ -169,11 +169,11 @@ class MultiHeadTest(test.TestCase):
self.assertAllClose(
expected_probabilities['head1'],
sess.run(
- spec.export_outputs['predict/head1'].outputs['probabilities']))
+ spec.export_outputs['head1/predict'].outputs['probabilities']))
self.assertAllClose(
expected_probabilities['head2'],
sess.run(
- spec.export_outputs['predict/head2'].outputs['probabilities']))
+ spec.export_outputs['head2/predict'].outputs['probabilities']))
def test_predict_two_heads_logits_tensor(self):
"""Tests predict with logits as Tensor."""
@@ -197,8 +197,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
@@ -254,8 +254,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1',
- 'predict/head1', 'head2', 'regression/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',
+ 'head1/predict', 'head2', 'head2/regression', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.