diff options
author | Derek Murray <mrry@google.com> | 2018-06-08 09:00:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-08 09:02:36 -0700 |
commit | ef1555172d452539d749340cdb076f0a24f6c505 (patch) | |
tree | a291fef41525ec77f68fd0b6c987475807cc9852 /tensorflow/contrib/training | |
parent | 7b5d9e86e77bb750d5b794f1673fc08d4d289ec7 (diff) |
[tf.data] Improve the error message for `Dataset.padded_batch()`.
Previously, we accepted the `padded_shapes` argument without validating that
it was compatible with the `input_dataset.output_shapes`. In many cases, we have
enough static shape information to do this, and so we now raise an actionable
error at the point where the mistake is committed, rather than at runtime.
PiperOrigin-RevId: 199800348
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/tensor_queue_dataset.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c..a2444934bc 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) |