diff options
author | 2018-08-06 11:40:39 -0700 | |
---|---|---|
committer | 2018-08-06 11:40:39 -0700 | |
commit | 20adbf1c5ba313210ac496ec61472319690739be (patch) | |
tree | a791da5df40f9b689215295c370b5e6e44ef4c29 /tensorflow/python/estimator/export | |
parent | d0ec94e71091a989583fc4fb760c1301f3921292 (diff) | |
parent | ec24a640d90146f24c0f386f39538f029a51dbf6 (diff) |
Merge pull request #20056 from msamogh:master
PiperOrigin-RevId: 207578381
Diffstat (limited to 'tensorflow/python/estimator/export')
-rw-r--r-- | tensorflow/python/estimator/export/export.py | 75 | ||||
-rw-r--r-- | tensorflow/python/estimator/export/export_test.py | 6 |
2 files changed, 37 insertions, 44 deletions
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index ca26341445..529e7a8b87 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,29 +40,38 @@ _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' _SINGLE_LABEL_DEFAULT_NAME = 'label' +_SINGLE_TENSOR_DEFAULT_NAMES = { + 'feature': _SINGLE_FEATURE_DEFAULT_NAME, + 'label': _SINGLE_LABEL_DEFAULT_NAME, + 'receiver_tensor': _SINGLE_RECEIVER_DEFAULT_NAME, + 'receiver_tensors_alternative': _SINGLE_RECEIVER_DEFAULT_NAME +} + -def _wrap_and_check_receiver_tensors(receiver_tensors): - """Ensure that receiver_tensors is a dict of str to Tensor mappings. +def _wrap_and_check_input_tensors(tensors, field_name): + """Ensure that tensors is a dict of str to Tensor mappings. Args: - receiver_tensors: dict of str to Tensors, or a single Tensor. + tensors: dict of str to Tensors, or a single Tensor. + field_name: name of the member field of `ServingInputReceiver` + whose value is being passed to `tensors`. Returns: dict of str to Tensors; this is the original dict if one was passed, or the original tensor wrapped in a dictionary. Raises: - ValueError: if receiver_tensors is None, or has non-string keys, + ValueError: if tensors is None, or has non-string keys, or non-Tensor values """ - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - _check_tensor_key(name, error_label='receiver_tensors') - _check_tensor(tensor, name, error_label='receiver_tensor') - return receiver_tensors + if tensors is None: + raise ValueError('{}s must be defined.'.format(field_name)) + if not isinstance(tensors, dict): + tensors = {_SINGLE_TENSOR_DEFAULT_NAMES[field_name]: tensors} + for name, tensor in tensors.items(): + _check_tensor_key(name, error_label=field_name) + _check_tensor(tensor, name, error_label=field_name) + return tensors def _check_tensor(tensor, name, error_label='feature'): @@ -125,15 +134,10 @@ class ServingInputReceiver( features, receiver_tensors, receiver_tensors_alternatives=None): - if features is None: - raise ValueError('features must be defined.') - if not isinstance(features, dict): - features = {_SINGLE_FEATURE_DEFAULT_NAME: features} - for name, tensor in features.items(): - _check_tensor_key(name) - _check_tensor(tensor, name) + features = _wrap_and_check_input_tensors(features, 'feature') - receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors, + 'receiver_tensor') if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -142,17 +146,10 @@ class ServingInputReceiver( receiver_tensors_alternatives)) for alternative_name, receiver_tensors_alt in ( six.iteritems(receiver_tensors_alternatives)): - if not isinstance(receiver_tensors_alt, dict): - receiver_tensors_alt = { - _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt - } - # Updating dict during iteration is OK in this case. - receiver_tensors_alternatives[alternative_name] = ( - receiver_tensors_alt) - for name, tensor in receiver_tensors_alt.items(): - _check_tensor_key(name, error_label='receiver_tensors_alternative') - _check_tensor( - tensor, name, error_label='receiver_tensors_alternative') + # Updating dict during iteration is OK in this case. + receiver_tensors_alternatives[alternative_name] = ( + _wrap_and_check_input_tensors( + receiver_tensors_alt, 'receiver_tensors_alternative')) return super(ServingInputReceiver, cls).__new__( cls, @@ -245,16 +242,12 @@ class SupervisedInputReceiver( def __new__(cls, features, labels, receiver_tensors): # Both features and labels can be dicts or raw tensors. for input_vals, error_label in ((features, 'feature'), (labels, 'label')): - if input_vals is None: - raise ValueError('{}s must be defined.'.format(error_label)) - if isinstance(input_vals, dict): - for name, tensor in input_vals.items(): - _check_tensor_key(name, error_label=error_label) - _check_tensor(tensor, name, error_label=error_label) - else: - _check_tensor(input_vals, None, error_label=error_label) - - receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + # _wrap_and_check_input_tensors is called here only to validate the + # tensors. The wrapped dict that is returned is deliberately discarded. + _wrap_and_check_input_tensors(input_vals, error_label) + + receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors, + 'receiver_tensor') return super(SupervisedInputReceiver, cls).__new__( cls, diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index a7074712c2..d2ac7f0b3b 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -107,7 +107,7 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.ServingInputReceiver( features=features, receiver_tensors={ @@ -271,7 +271,7 @@ class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.SupervisedInputReceiver( features=features, labels=labels, @@ -740,7 +740,7 @@ class TensorServingReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.TensorServingInputReceiver( features=features, receiver_tensors={ |