aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/export
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 11:40:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 11:40:39 -0700
commit20adbf1c5ba313210ac496ec61472319690739be (patch)
treea791da5df40f9b689215295c370b5e6e44ef4c29 /tensorflow/python/estimator/export
parentd0ec94e71091a989583fc4fb760c1301f3921292 (diff)
parentec24a640d90146f24c0f386f39538f029a51dbf6 (diff)
Merge pull request #20056 from msamogh:master
PiperOrigin-RevId: 207578381
Diffstat (limited to 'tensorflow/python/estimator/export')
-rw-r--r--tensorflow/python/estimator/export/export.py75
-rw-r--r--tensorflow/python/estimator/export/export_test.py6
2 files changed, 37 insertions, 44 deletions
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ca26341445..529e7a8b87 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -40,29 +40,38 @@ _SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
_SINGLE_LABEL_DEFAULT_NAME = 'label'
+_SINGLE_TENSOR_DEFAULT_NAMES = {
+ 'feature': _SINGLE_FEATURE_DEFAULT_NAME,
+ 'label': _SINGLE_LABEL_DEFAULT_NAME,
+ 'receiver_tensor': _SINGLE_RECEIVER_DEFAULT_NAME,
+ 'receiver_tensors_alternative': _SINGLE_RECEIVER_DEFAULT_NAME
+}
+
-def _wrap_and_check_receiver_tensors(receiver_tensors):
- """Ensure that receiver_tensors is a dict of str to Tensor mappings.
+def _wrap_and_check_input_tensors(tensors, field_name):
+ """Ensure that tensors is a dict of str to Tensor mappings.
Args:
- receiver_tensors: dict of str to Tensors, or a single Tensor.
+ tensors: dict of str to Tensors, or a single Tensor.
+ field_name: name of the member field of `ServingInputReceiver`
+ whose value is being passed to `tensors`.
Returns:
dict of str to Tensors; this is the original dict if one was passed, or
the original tensor wrapped in a dictionary.
Raises:
- ValueError: if receiver_tensors is None, or has non-string keys,
+ ValueError: if tensors is None, or has non-string keys,
or non-Tensor values
"""
- if receiver_tensors is None:
- raise ValueError('receiver_tensors must be defined.')
- if not isinstance(receiver_tensors, dict):
- receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
- for name, tensor in receiver_tensors.items():
- _check_tensor_key(name, error_label='receiver_tensors')
- _check_tensor(tensor, name, error_label='receiver_tensor')
- return receiver_tensors
+ if tensors is None:
+ raise ValueError('{}s must be defined.'.format(field_name))
+ if not isinstance(tensors, dict):
+ tensors = {_SINGLE_TENSOR_DEFAULT_NAMES[field_name]: tensors}
+ for name, tensor in tensors.items():
+ _check_tensor_key(name, error_label=field_name)
+ _check_tensor(tensor, name, error_label=field_name)
+ return tensors
def _check_tensor(tensor, name, error_label='feature'):
@@ -125,15 +134,10 @@ class ServingInputReceiver(
features,
receiver_tensors,
receiver_tensors_alternatives=None):
- if features is None:
- raise ValueError('features must be defined.')
- if not isinstance(features, dict):
- features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
- for name, tensor in features.items():
- _check_tensor_key(name)
- _check_tensor(tensor, name)
+ features = _wrap_and_check_input_tensors(features, 'feature')
- receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
+ receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors,
+ 'receiver_tensor')
if receiver_tensors_alternatives is not None:
if not isinstance(receiver_tensors_alternatives, dict):
@@ -142,17 +146,10 @@ class ServingInputReceiver(
receiver_tensors_alternatives))
for alternative_name, receiver_tensors_alt in (
six.iteritems(receiver_tensors_alternatives)):
- if not isinstance(receiver_tensors_alt, dict):
- receiver_tensors_alt = {
- _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
- }
- # Updating dict during iteration is OK in this case.
- receiver_tensors_alternatives[alternative_name] = (
- receiver_tensors_alt)
- for name, tensor in receiver_tensors_alt.items():
- _check_tensor_key(name, error_label='receiver_tensors_alternative')
- _check_tensor(
- tensor, name, error_label='receiver_tensors_alternative')
+ # Updating dict during iteration is OK in this case.
+ receiver_tensors_alternatives[alternative_name] = (
+ _wrap_and_check_input_tensors(
+ receiver_tensors_alt, 'receiver_tensors_alternative'))
return super(ServingInputReceiver, cls).__new__(
cls,
@@ -245,16 +242,12 @@ class SupervisedInputReceiver(
def __new__(cls, features, labels, receiver_tensors):
# Both features and labels can be dicts or raw tensors.
for input_vals, error_label in ((features, 'feature'), (labels, 'label')):
- if input_vals is None:
- raise ValueError('{}s must be defined.'.format(error_label))
- if isinstance(input_vals, dict):
- for name, tensor in input_vals.items():
- _check_tensor_key(name, error_label=error_label)
- _check_tensor(tensor, name, error_label=error_label)
- else:
- _check_tensor(input_vals, None, error_label=error_label)
-
- receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
+ # _wrap_and_check_input_tensors is called here only to validate the
+ # tensors. The wrapped dict that is returned is deliberately discarded.
+ _wrap_and_check_input_tensors(input_vals, error_label)
+
+ receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors,
+ 'receiver_tensor')
return super(SupervisedInputReceiver, cls).__new__(
cls,
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index a7074712c2..d2ac7f0b3b 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -107,7 +107,7 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.ServingInputReceiver(
features=features,
receiver_tensors={
@@ -271,7 +271,7 @@ class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.SupervisedInputReceiver(
features=features,
labels=labels,
@@ -740,7 +740,7 @@ class TensorServingReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.TensorServingInputReceiver(
features=features,
receiver_tensors={