aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-11 17:13:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 17:19:05 -0700
commitbbee0c4c26d94aa7f0115f984116167052afa11e (patch)
tree97776553e7b3cc4fb2c23c43bf57edb8fa8cb196
parentc8980fd1b4d3a74de0214690f810d0c93da2558f (diff)
Checking that TPUEstimator model function features have static shapes.
PiperOrigin-RevId: 200139880
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 64ae35dfc5..2521522752 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1343,8 +1343,55 @@ class _ModelFnWrapper(object):
key, tensor))
return predictions
+ def _validate_model_features_and_labels(self,
+ features,
+ labels,
+ is_export_mode):
+ """Validates that the features and labels for the model function are valid.
+
+ A valid features/labels object is the one with:
+ - Type: Tensor or a dictionary of Tensors
+ - Static shape if is_export_mode is False.
+
+ Args:
+ features: the features that would be input to the model function.
+ labels: the labels that would be input to the model function.
+ is_export_mode: boolean value specifying if in export mode.
+
+ Raises:
+ TypeError: If features/labels are not of the correct type.
+ ValueError: If features/labels have dynamic shape.
+ """
+
+ def validate(obj, obj_name):
+ """Helper validate function."""
+ if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
+ raise TypeError(
+ 'The {} to the model returned by input_fn must be either a Tensor '
+ 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
+ obj))
+ if is_export_mode:
+ return
+ if isinstance(obj, ops.Tensor):
+ if not obj.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static shape.'
+ ' Tensor: {}'.format(obj_name, obj))
+ else:
+ for (key, tensor) in obj.items():
+ if not tensor.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static '
+ 'shape. Key: \'{}\', Tensor: {}'.format(
+ obj_name, key, tensor))
+
+ validate(features, 'features')
+ if labels is not None:
+ validate(labels, 'labels')
+
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
+ self._validate_model_features_and_labels(features, labels, is_export_mode)
model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}