aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 5abb38c2d2..75482f67da 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -413,6 +413,20 @@ class BatchDatasetTest(test.TestCase):
def testMapAndBatchPartialBatchDropRemainder(self):
return self._testMapAndBatchPartialBatchHelper(drop_remainder=True)
+ def testMapAndBatchYieldsPartialBatch(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .apply(batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]), 4))
+ .make_one_shot_iterator())
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.test_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testMapAndBatchSparse(self):
def _sparse(i):