aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-05-22 12:36:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-22 12:39:28 -0700
commita4c9efe6a5bf143f844b1cffbdc839c399620b9b (patch)
tree1c6df271c0aeb64d8027b037ba64bf57eaabc972
parent67b6696f9620734369ae99e7895fa6570d7faca6 (diff)
Detect unknown batch size in predictions dict
PiperOrigin-RevId: 197606059
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 77d117ba78..f0c7564175 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1264,13 +1264,11 @@ class _ModelFnWrapper(object):
'estimator_spec used by TPU prediction must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
+ self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
+
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
to_record = {}
identity_fn = lambda **kwargs: kwargs
- # TODO(xiejw): Adds validation for prediction dictionrary.
- # TODO(xiejw): Adds support for single tensor as predictions.
- if not isinstance(tpu_estimator_spec.predictions, dict):
- raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
to_record['signals'] = [identity_fn, stopping_signals]
if tpu_estimator_spec.host_call is not None:
@@ -1282,6 +1280,21 @@ class _ModelFnWrapper(object):
return predict_step, host_calls, captured_scaffold_fn
+ def _verify_tpu_spec_predictions(self, predictions):
+ """Validates TPUEstimatorSpec.predictions dict."""
+ # TODO(xiejw): Adds validation for prediction dictionrary.
+ # TODO(xiejw): Adds support for single tensor as predictions.
+ if not isinstance(predictions, dict):
+ raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
+
+ for (key, tensor) in predictions.items():
+ if tensor.shape[0].value is None:
+ raise ValueError(
+ 'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
+ 'dynamic shape (should be static). Tensor: {}'.format(
+ key, tensor))
+ return predictions
+
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
model_fn_args = function_utils.fn_args(self._model_fn)