author    Derek Murray <mrry@google.com>  2018-03-14 18:39:25 -0700
committer Derek Murray <mrry@google.com>  2018-03-16 12:07:48 -0700
commit    6655570f12dba22fe752796471635e109e682056 (patch)
tree      a6be19200d10702d32129f93d94d32aee2acde4e
parent    89e77c4ed37af1fae283e27f4de62578fa9b96fc (diff)
[tf.data] Fix Python shape inference for `tf.contrib.data.map_and_batch()`.
Previously, it would incorrectly report that all batches have the same size, not accounting for the possibility of the last batch being partial.

Fixes #17720.

PiperOrigin-RevId: 189121488
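For illustration only (not part of this commit), a minimal sketch of the behavior being fixed, assuming a TF 1.x graph-mode build where `tf.contrib.data` is still available: a 10-element dataset batched by 4 ends in a partial batch of 2, so the statically inferred leading dimension should be unknown rather than the batch size.

import tensorflow as tf

# Ten scalar elements batched by 4 yield batches of 4, 4, and 2.
dataset = tf.data.Dataset.range(10).apply(
    tf.contrib.data.map_and_batch(lambda x: x * x, batch_size=4))

# Before this change the inferred shape claimed a fixed leading dimension
# of 4; with the fix it is reported as unknown.
print(dataset.output_shapes)  # (?,) after this fix; previously (4,)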
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 14
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py                       |  3
2 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 71dc1c1172..a2da953c7b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -387,6 +387,20 @@ class BatchDatasetTest(test.TestCase):
   def testBatchAndMapDatasetWithParallelBatching(self):
     return self._testBatchAndMapDatasetHelper(num_parallel_batches=10)
 
+  def testMapAndBatchYieldsPartialBatch(self):
+    iterator = (dataset_ops.Dataset.range(10)
+                .apply(batching.map_and_batch(
+                    lambda x: array_ops.reshape(x * x, [1]), 4))
+                .make_one_shot_iterator())
+    self.assertEqual([None, 1], iterator.output_shapes.as_list())
+    next_element = iterator.get_next()
+    with self.test_session() as sess:
+      self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+      self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+      self.assertAllEqual([[64], [81]], sess.run(next_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testMapAndBatchSparse(self):
 
     def _sparse(i):
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 6eb512dec6..6463d75750 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -374,8 +374,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset):
   @property
   def output_shapes(self):
     return nest.pack_sequence_as(self._output_shapes, [
-        tensor_shape.vector(tensor_util.constant_value(
-            self._batch_size)).concatenate(s)
+        tensor_shape.vector(None).concatenate(s)
         for s in nest.flatten(self._output_shapes)
     ])
 