diff options
author | 2018-03-14 18:39:25 -0700 | |
---|---|---|
committer | 2018-03-16 12:07:48 -0700 | |
commit | 6655570f12dba22fe752796471635e109e682056 (patch) | |
tree | a6be19200d10702d32129f93d94d32aee2acde4e | |
parent | 89e77c4ed37af1fae283e27f4de62578fa9b96fc (diff) |
[tf.data] Fix Python shape inference for `tf.contrib.data.map_and_batch()`.
Previously, it would incorrectly report that all batches have the same size, not accounting for the possibility of the last batch being partial.
Fixes #17720.
PiperOrigin-RevId: 189121488
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 14 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 3 |
2 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 71dc1c1172..a2da953c7b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -387,6 +387,20 @@ class BatchDatasetTest(test.TestCase): def testBatchAndMapDatasetWithParallelBatching(self): return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testMapAndBatchSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6eb512dec6..6463d75750 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -374,8 +374,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): @property def output_shapes(self): return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(tensor_util.constant_value( - self._batch_size)).concatenate(s) + tensor_shape.vector(None).concatenate(s) for s in nest.flatten(self._output_shapes) ]) |