aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-06-08 09:00:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 09:02:36 -0700
commitef1555172d452539d749340cdb076f0a24f6c505 (patch)
treea291fef41525ec77f68fd0b6c987475807cc9852 /tensorflow/contrib/training
parent7b5d9e86e77bb750d5b794f1673fc08d4d289ec7 (diff)
[tf.data] Improve the error message for `Dataset.padded_batch()`.
Previously, we accepted the `padded_shapes` argument without validating that it was compatible with the `input_dataset.output_shapes`. In many cases, we have enough static shape information to do this, and so we now raise an actionable error at the point where the mistake is committed, rather than at runtime. PiperOrigin-RevId: 199800348
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index 409aba817c..a2444934bc 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- # pylint: disable=protected-access
if padded_shapes is None:
self._padded_shapes = nest.map_structure(
- dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ convert.partial_shape_to_tensor, input_dataset.output_shapes)
else:
self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ input_dataset.output_shapes, convert.partial_shape_to_tensor,
padded_shapes)
+ # pylint: disable=protected-access
padding_values = (
padding_values if padding_values is not None else
dataset_ops._default_padding(input_dataset))