diff options
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/bucketing_test.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 5fc7e51d81..2022c1f2bd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -616,7 +616,44 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(boundaries), sorted(lengths_val)) + self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], + sorted(lengths_val)) + + def testPadToBoundaryNoExtraneousPadding(self): + + boundaries = [3, 7, 11] + batch_sizes = [2, 2, 2, 2] + lengths = range(1, 11) + + def element_gen(): + for length in lengths: + yield ([1] * length,) + + element_len = lambda element: array_ops.shape(element)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(5): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + + self.assertAllEqual(batches[0], [[1, 0], + [1, 1]]) + self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]) + self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) def testTupleElements(self): |