Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/bucketing_test.py')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 276
1 file changed, 207 insertions(+), 69 deletions(-)
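
Most hunks below are the same mechanical substitution: `self.test_session()` becomes `self.cached_session()`, which returns a session that is created once and then reused across calls within the same test method, instead of being rebuilt for every `with` block. A minimal sketch of the pattern, using a hypothetical `SquareTest` case that is not part of this diff:

    import tensorflow as tf


    class SquareTest(tf.test.TestCase):
      # Hypothetical test case, for illustration only.

      def testSquare(self):
        y = tf.square(tf.constant([1, 2, 3]))
        # cached_session() hands back the same underlying session on
        # every call inside this test method, so repeated blocks are cheap.
        with self.cached_session() as sess:
          self.assertAllEqual([1, 4, 9], sess.run(y))
        with self.cached_session() as sess:  # reuses the cached session
          self.assertAllEqual([1, 4, 9], sess.run(y))


    if __name__ == "__main__":
      tf.test.main()
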
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f2bd..48971f2ccc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -531,6 +531,45 @@ class BucketTest(test.TestCase):
self.assertEqual(batches, 15)
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
class BucketBySequenceLength(test.TestCase):
def testBucket(self):
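
The helper block added above drives every test in two modes: `_format_record` returns either the raw dense list or the keyword arguments for a `SparseTensor`, which `_to_sparse_tensor` then expands via `SparseTensor(**record)`. A standalone illustration of the two return shapes, with the helper duplicated from the diff so the snippet runs on its own without TensorFlow:

    def _format_record(array, sparse):
      if sparse:
        return {
            "values": array,
            "indices": [[i] for i in range(len(array))],
            "dense_shape": (len(array),)
        }
      return array

    print(_format_record([1, 1, 1], sparse=False))
    # [1, 1, 1]
    print(_format_record([1, 1, 1], sparse=True))
    # {'values': [1, 1, 1], 'indices': [[0], [1], [2]], 'dense_shape': (3,)}
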
@@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase):
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25, 35]
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.test_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
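
The assertions in `_test_bucket_by_padding` lean on two behaviors. First, with `no_padding=True` the batches come back as `SparseTensor`s (hence `batch.dense_shape` and `batch.values` above), while `no_padding=False` yields dense padded arrays. Second, `bucket_by_sequence_length` assigns an element of length L to bucket i such that boundaries[i-1] <= L < boundaries[i], with implicit open ends. A pure-Python sketch of that assignment rule as I read the contrib implementation; the boundary values here are illustrative, not taken from the elided `boundaries` definition above:

    def bucket_id(length, boundaries):
      # Bucket i covers the half-open range [boundaries[i-1], boundaries[i]);
      # lengths below the first boundary land in bucket 0, lengths at or
      # above the last boundary land in the final bucket.
      for i, boundary in enumerate(boundaries):
        if length < boundary:
          return i
      return len(boundaries)

    # With illustrative boundaries [10, 20, 30], the four test lengths
    # [8, 13, 25, 35] fall into four distinct buckets, which is why the
    # test expects exactly one batch per bucket.
    assert [bucket_id(l, [10, 20, 30]) for l in (8, 13, 25, 35)] == [0, 1, 2, 3]
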
@@ -600,7 +658,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +695,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
@@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase):
def testTupleElements(self):
- def elements_gen():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (x, y)
-
- def element_length_fn(x, y):
- del y
- return array_ops.shape(x)[0]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator=elements_gen,
- output_shapes=(tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([])),
- output_types=(dtypes.int32, dtypes.int32))
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ The test runs on the following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8]))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
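
For `testBucketSparse`, `_compute_expected_batches` encodes the claim that each bucket of `bucket_size` consecutive lengths is split into batches of at most `batch_size` records, with a short remainder batch per bucket. A pure-Python sketch of that partitioning, using small illustrative parameters rather than the test's actual 0/100/7/10:

    # Each bucket of `bucket_size` consecutive records splits into batches
    # of at most `batch_size`; the last batch in a bucket may be short.
    min_len, max_len, batch_size, bucket_size = 0, 20, 7, 10

    batch_extents = []
    for bucket_start in range(min_len, max_len, bucket_size):
      for offset in range(0, bucket_size, batch_size):
        start = bucket_start + offset
        end = min(start + batch_size, bucket_start + bucket_size)
        batch_extents.append((start, end))

    # Two buckets of 10 records each -> batches of 7 and 3 per bucket.
    assert batch_extents == [(0, 7), (7, 10), (10, 17), (17, 20)]
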