diff options
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/bucketing_test.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 276 |
1 files changed, 207 insertions, 69 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 2022c1f2bd..48971f2ccc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for expected in values: got = sess.run(get_next) self.assertEqual(got, expected) @@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase): self.assertIs(None, dataset.output_shapes[1].ndims) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual([0] * (2**i), x) self.assertAllEqual(np.array(1, ndmin=i), y) @@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase): (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual(x, np.asarray([x for x in range(10)])) self.assertEqual(y, 45) @@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # The input is infinite, so this test demonstrates that: # 1. We produce output without having to consume the entire input, @@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) @@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -376,7 +376,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) which_bucket, bucketed_values = sess.run(get_next) @@ -411,7 +411,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches (one containing even values, one containing odds) @@ -482,7 +482,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches ([0, 2, ...] and [64, 66, ...]) @@ -515,7 +515,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): batches = 0 @@ -531,6 +531,45 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +def _element_length_fn(x, y=None): + del y + return array_ops.shape(x)[0] + + +def _to_sparse_tensor(record): + return sparse_tensor.SparseTensor(**record) + + +def _format_record(array, sparse): + if sparse: + return { + "values": array, + "indices": [[i] for i in range(len(array))], + "dense_shape": (len(array),) + } + return array + + +def _get_record_type(sparse): + if sparse: + return { + "values": dtypes.int64, + "indices": dtypes.int64, + "dense_shape": dtypes.int64 + } + return dtypes.int32 + + +def _get_record_shape(sparse): + if sparse: + return { + "values": tensor_shape.TensorShape([None,]), + "indices": tensor_shape.TensorShape([None, 1]), + "dense_shape": tensor_shape.TensorShape([1,]) + } + return tensor_shape.TensorShape([None]) + + class BucketBySequenceLength(test.TestCase): def testBucket(self): @@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes, lengths): - for _ in range(batch_size): - elements.append([1] * length) - random.shuffle(elements) - for el in elements: - yield (el,) - - element_len = lambda el: array_ops.shape(el)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.test_session() as sess: - batches = [] - for _ in range(4): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - batch_size = batch.shape[0] - length = batch.shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(lengths), sorted(lengths_val)) + def build_dataset(sparse): + def _generator(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + record_len = length - 1 + for _ in range(batch_size): + elements.append([1] * record_len) + record_len = length + random.shuffle(elements) + for el in elements: + yield (_format_record(el, sparse),) + dataset = dataset_ops.Dataset.from_generator( + _generator, + (_get_record_type(sparse),), + (_get_record_shape(sparse),)) + if sparse: + dataset = dataset.map(lambda x: (_to_sparse_tensor(x),)) + return dataset + + def _test_bucket_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply( + grouping.bucket_by_sequence_length( + _element_length_fn, + boundaries, + batch_sizes, + no_padding=no_padding)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + shape = batch.dense_shape if no_padding else batch.shape + batch_size = shape[0] + length = shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + sum_check = batch.values.sum() if no_padding else batch.sum() + self.assertEqual(sum_check, batch_size * length - 1) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + for no_padding in (True, False): + _test_bucket_by_padding(no_padding) def testPadToBoundary(self): @@ -600,7 +658,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(3): batches.append(sess.run(batch)) @@ -637,7 +695,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(5): batches.append(sess.run(batch)) @@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase): def testTupleElements(self): - def elements_gen(): - text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] - label = [1, 2, 1, 2] - for x, y in zip(text, label): - yield (x, y) - - def element_length_fn(x, y): - del y - return array_ops.shape(x)[0] - - dataset = dataset_ops.Dataset.from_generator( - generator=elements_gen, - output_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([])), - output_types=(dtypes.int32, dtypes.int32)) + def build_dataset(sparse): + def _generator(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (_format_record(x, sparse), y) + dataset = dataset_ops.Dataset.from_generator( + generator=_generator, + output_types=(_get_record_type(sparse), dtypes.int32), + output_shapes=(_get_record_shape(sparse), + tensor_shape.TensorShape([]))) + if sparse: + dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) + return dataset + + def _test_tuple_elements_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=_element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8], + no_padding=no_padding)) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + for no_padding in (True, False): + _test_tuple_elements_by_padding(no_padding) + + def testBucketSparse(self): + """Tests bucketing of sparse tensors (case where `no_padding` == True). + + Test runs on following dataset: + [ + [0], + [0, 1], + [0, 1, 2] + ... + [0, ..., max_len - 1] + ] + Sequences are bucketed by length and batched with + `batch_size` < `bucket_size`. + """ + + min_len = 0 + max_len = 100 + batch_size = 7 + bucket_size = 10 + + def _build_dataset(): + input_data = [range(i+1) for i in range(min_len, max_len)] + def generator_fn(): + for record in input_data: + yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( + generator=generator_fn, + output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) + return dataset + + def _compute_expected_batches(): + """Computes expected batch outputs and stores in a set.""" + all_expected_sparse_tensors = set() + for bucket_start_len in range(min_len, max_len, bucket_size): + for batch_offset in range(0, bucket_size, batch_size): + batch_start_len = bucket_start_len + batch_offset + batch_end_len = min(batch_start_len + batch_size, + bucket_start_len + bucket_size) + expected_indices = [] + expected_values = [] + for length in range(batch_start_len, batch_end_len): + for val in range(length + 1): + expected_indices.append((length - batch_start_len, val)) + expected_values.append(val) + expected_sprs_tensor = (tuple(expected_indices), + tuple(expected_values)) + all_expected_sparse_tensors.add(expected_sprs_tensor) + return all_expected_sparse_tensors + + def _compute_batches(dataset): + """Computes actual batch outputs of dataset and stores in a set.""" + batch = dataset.make_one_shot_iterator().get_next() + all_sparse_tensors = set() + with self.cached_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + while True: + output = sess.run(batch) + sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), + tuple(output.values)) + all_sparse_tensors.add(sprs_tensor) + return all_sparse_tensors + + dataset = _build_dataset() + boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) dataset = dataset.apply(grouping.bucket_by_sequence_length( - element_length_func=element_length_fn, - bucket_batch_sizes=[2, 2, 2], - bucket_boundaries=[0, 8])) - shapes = dataset.output_shapes - self.assertEqual([None, None], shapes[0].as_list()) - self.assertEqual([None], shapes[1].as_list()) + _element_length_fn, + boundaries, + [batch_size] * (len(boundaries) + 1), + no_padding=True)) + batches = _compute_batches(dataset) + expected_batches = _compute_expected_batches() + self.assertEqual(batches, expected_batches) if __name__ == "__main__": |