diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-09 12:37:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-09 12:39:32 -0700 |
commit | 970baf64cffe9de0b124b5eea53b1ee1d5158506 (patch) | |
tree | b3295a5288030bde2bf7147d997285d36954a75b | |
parent | 2e6f8b3f05fe2d212c19b9598f93f4e6ee07675f (diff) |
Renames exported signature names in MultiHead so head_name comes first.
PiperOrigin-RevId: 192168628
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/multi_head.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 16 |
2 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index bbbc19cc4d..ce75899214 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -345,7 +345,7 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access if k == _DEFAULT_SERVING_KEY: key = head_name else: - key = '%s/%s' % (k, head_name) + key = '%s/%s' % (head_name, k) export_outputs[key] = v if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access isinstance(v, export_output_lib.PredictOutput)): diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index d9e5aca295..3d6fccb118 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', - 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -169,11 +169,11 @@ class MultiHeadTest(test.TestCase): self.assertAllClose( expected_probabilities['head1'], sess.run( - spec.export_outputs['predict/head1'].outputs['probabilities'])) + spec.export_outputs['head1/predict'].outputs['probabilities'])) self.assertAllClose( expected_probabilities['head2'], sess.run( - spec.export_outputs['predict/head2'].outputs['probabilities'])) + spec.export_outputs['head2/predict'].outputs['probabilities'])) def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" @@ -197,8 +197,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', - 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -254,8 +254,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1', - 'predict/head1', 'head2', 'regression/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression', + 'head1/predict', 'head2', 'head2/regression', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. |