author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-12 12:28:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-12 12:33:10 -0700
commit     f337425dc71e3ea95aa91ce401a40c1b594486ca (patch)
tree       48b112cbbe6e37ed14c04df3ab3d0c3460960f8c /tensorflow/contrib/data
parent     f02a2fad042bc401f3d1c89a9fd52e40ca5d1835 (diff)
Added the ability to bucket without padding (emitting batches as sparse tensors) to `bucket_by_sequence_length`.
PiperOrigin-RevId: 212684420
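
To illustrate the change, here is a minimal usage sketch distilled from the tests in this diff; the generator, bucket boundaries, and batch sizes are illustrative, and the contrib-era TF 1.x API is assumed:

import tensorflow as tf

from tensorflow.contrib import layers
from tensorflow.contrib.data.python.ops import grouping


def element_gen():
  # Illustrative variable-length sequences of ones.
  for length in (3, 5, 7, 9):
    yield ([1] * length,)


dataset = tf.data.Dataset.from_generator(
    element_gen, (tf.int64,), ([None],))
# With no_padding=True, batch features must be tf.SparseTensor (or all of
# the same shape), so convert each dense row first.
dataset = dataset.map(lambda x: (layers.dense_to_sparse(x),))
dataset = dataset.apply(
    grouping.bucket_by_sequence_length(
        # tf.shape() also handles SparseTensor, via its dense_shape.
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[4, 8],
        bucket_batch_sizes=[2, 2, 2],
        no_padding=True))
# Each emitted batch is a tf.SparseTensor; no length padding is applied.
batch, = dataset.make_one_shot_iterator().get_next()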
Diffstat (limited to 'tensorflow/contrib/data')
 tensorflow/contrib/data/python/kernel_tests/BUILD             |   1
 tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 174
 tensorflow/contrib/data/python/ops/grouping.py                |   9
 3 files changed, 144 insertions(+), 40 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b3c90ded39..1f947e97f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -44,6 +44,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/contrib/data/python/ops:grouping",
+        "//tensorflow/contrib/layers:layers_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 293be2bd06..94718bb477 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -21,6 +21,7 @@ import random
 
 import numpy as np
 
+from tensorflow.contrib import layers
 from tensorflow.contrib.data.python.ops import grouping
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -531,6 +532,11 @@ class BucketTest(test.TestCase):
       self.assertEqual(batches, 15)
 
 
+def _element_length_fn(x, y=None):
+  del y
+  return array_ops.shape(x)[0]
+
+
 class BucketBySequenceLength(test.TestCase):
 
   def testBucket(self):
@@ -543,35 +549,49 @@ class BucketBySequenceLength(test.TestCase):
       # Produce 1 batch for each bucket
       elements = []
       for batch_size, length in zip(batch_sizes, lengths):
+        record_len = length - 1
         for _ in range(batch_size):
-          elements.append([1] * length)
+          elements.append([1] * record_len)
+          record_len = length
       random.shuffle(elements)
       for el in elements:
         yield (el,)
 
-    element_len = lambda el: array_ops.shape(el)[0]
-    dataset = dataset_ops.Dataset.from_generator(
-        element_gen, (dtypes.int64,), ([None],)).apply(
-            grouping.bucket_by_sequence_length(
-                element_len, boundaries, batch_sizes))
-    batch, = dataset.make_one_shot_iterator().get_next()
+    def _test_bucket_by_padding(no_padding):
+      dataset = dataset_ops.Dataset.from_generator(
+          element_gen, (dtypes.int64,), ([None],))
+      if no_padding:
+        dataset = dataset.map(lambda x: (layers.dense_to_sparse(x),))
+      dataset = dataset.apply(
+          grouping.bucket_by_sequence_length(
+              _element_length_fn,
+              boundaries,
+              batch_sizes,
+              no_padding=no_padding))
+      batch, = dataset.make_one_shot_iterator().get_next()
 
-    with self.cached_session() as sess:
-      batches = []
-      for _ in range(4):
-        batches.append(sess.run(batch))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(batch)
-    batch_sizes_val = []
-    lengths_val = []
-    for batch in batches:
-      batch_size = batch.shape[0]
-      length = batch.shape[1]
-      batch_sizes_val.append(batch_size)
-      lengths_val.append(length)
-    self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
-    self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
-    self.assertEqual(sorted(lengths), sorted(lengths_val))
+      with self.cached_session() as sess:
+        batches = []
+        for _ in range(4):
+          batches.append(sess.run(batch))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(batch)
+      batch_sizes_val = []
+      lengths_val = []
+      for batch in batches:
+        shape = batch.dense_shape if no_padding else batch.shape
+        batch_size = shape[0]
+        length = shape[1]
+        batch_sizes_val.append(batch_size)
+        lengths_val.append(length)
+        sum_check = batch.values.sum() if no_padding else batch.sum()
+        self.assertEqual(sum_check, batch_size * length - 1)
+      self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+      self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+      self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+    for no_padding in (True, False):
+      _test_bucket_by_padding(no_padding)
 
   def testPadToBoundary(self):
 
@@ -663,22 +683,100 @@ class BucketBySequenceLength(test.TestCase):
       for x, y in zip(text, label):
         yield (x, y)
 
-    def element_length_fn(x, y):
-      del y
-      return array_ops.shape(x)[0]
-
-    dataset = dataset_ops.Dataset.from_generator(
-        generator=elements_gen,
-        output_shapes=(tensor_shape.TensorShape([None]),
-                       tensor_shape.TensorShape([])),
-        output_types=(dtypes.int32, dtypes.int32))
+    def _test_tuple_elements_by_padding(no_padding):
+      dataset = dataset_ops.Dataset.from_generator(
+          generator=elements_gen,
+          output_shapes=(tensor_shape.TensorShape([None]),
+                         tensor_shape.TensorShape([])),
+          output_types=(dtypes.int32, dtypes.int32))
+      if no_padding:
+        dataset = dataset.map(lambda x, y: (layers.dense_to_sparse(x), y))
+      dataset = dataset.apply(grouping.bucket_by_sequence_length(
+          element_length_func=_element_length_fn,
+          bucket_batch_sizes=[2, 2, 2],
+          bucket_boundaries=[0, 8],
+          no_padding=no_padding))
+      shapes = dataset.output_shapes
+      self.assertEqual([None, None], shapes[0].as_list())
+      self.assertEqual([None], shapes[1].as_list())
+
+    for no_padding in (True, False):
+      _test_tuple_elements_by_padding(no_padding)
+
+  def testBucketSparse(self):
+    """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+    Test runs on following dataset:
+      [
+        [0],
+        [0, 1],
+        [0, 1, 2]
+        ...
+        [0, ..., max_len - 1]
+      ]
+    Sequences are bucketed by length and batched with
+    `batch_size` < `bucket_size`.
+    """
+
+    min_len = 0
+    max_len = 100
+    batch_size = 7
+    bucket_size = 10
+
+    def _build_dataset():
+      input_data = [range(i+1) for i in range(min_len, max_len)]
+      def generator_fn():
+        for record in input_data:
+          yield record
+      dataset = dataset_ops.Dataset.from_generator(
+          generator=generator_fn,
+          output_shapes=(tensor_shape.TensorShape([None])),
+          output_types=(dtypes.int64))
+      dataset = dataset.map(lambda x: layers.dense_to_sparse(x, eos_token=-1))
+      return dataset
+
+    def _compute_expected_batches():
+      """Computes expected batch outputs and stores in a set."""
+      all_expected_sparse_tensors = set()
+      for bucket_start_len in range(min_len, max_len, bucket_size):
+        for batch_offset in range(0, bucket_size, batch_size):
+          batch_start_len = bucket_start_len + batch_offset
+          batch_end_len = min(batch_start_len + batch_size,
+                              bucket_start_len + bucket_size)
+          expected_indices = []
+          expected_values = []
+          for length in range(batch_start_len, batch_end_len):
+            for val in range(length + 1):
+              expected_indices.append((length - batch_start_len, val))
+              expected_values.append(val)
+          expected_sprs_tensor = (tuple(expected_indices),
+                                  tuple(expected_values))
+          all_expected_sparse_tensors.add(expected_sprs_tensor)
+      return all_expected_sparse_tensors
+
+    def _compute_batches(dataset):
+      """Computes actual batch outputs of dataset and stores in a set."""
+      batch = dataset.make_one_shot_iterator().get_next()
+      all_sparse_tensors = set()
+      with self.cached_session() as sess:
+        with self.assertRaises(errors.OutOfRangeError):
+          while True:
+            output = sess.run(batch)
+            sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+                           tuple(output.values))
+            all_sparse_tensors.add(sprs_tensor)
+      return all_sparse_tensors
+
+    dataset = _build_dataset()
+    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
     dataset = dataset.apply(grouping.bucket_by_sequence_length(
-        element_length_func=element_length_fn,
-        bucket_batch_sizes=[2, 2, 2],
-        bucket_boundaries=[0, 8]))
-    shapes = dataset.output_shapes
-    self.assertEqual([None, None], shapes[0].as_list())
-    self.assertEqual([None], shapes[1].as_list())
+        _element_length_fn,
+        boundaries,
+        [batch_size] * (len(boundaries) + 1),
+        no_padding=True))
+    batches = _compute_batches(dataset)
+    expected_batches = _compute_expected_batches()
+    self.assertEqual(batches, expected_batches)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6edc1d7990..099e10db92 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func,
                               bucket_batch_sizes,
                               padded_shapes=None,
                               padding_values=None,
-                              pad_to_bucket_boundary=False):
+                              pad_to_bucket_boundary=False,
+                              no_padding=False):
   """A transformation that buckets elements in a `Dataset` by length.
 
   Elements of the `Dataset` are grouped together by length and then are padded
@@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func,
       unknown size to bucket boundary minus 1 (i.e., the maximum length in
       each bucket), and caller must ensure that the source `Dataset` does not
       contain any elements with length longer than `max(bucket_boundaries)`.
+    no_padding: `bool`, indicates whether to pad the batch features (features
+      need to be either of type `tf.SparseTensor` or of same shape).
 
   Returns:
     A `Dataset` transformation function, which can be passed to
@@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func,
 
   def batching_fn(bucket_id, grouped_dataset):
     """Batch elements in dataset."""
-    batch_size = batch_sizes[bucket_id]
+    batch_size = window_size_fn(bucket_id)
+    if no_padding:
+      return grouped_dataset.batch(batch_size)
     none_filler = None
     if pad_to_bucket_boundary:
       err_msg = ("When pad_to_bucket_boundary=True, elements must have "
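
Design-wise, the `no_padding=True` path in `batching_fn` above falls back to plain `Dataset.batch` on each bucket's window, so downstream code receives `tf.SparseTensor` batches rather than padded dense ones. A sketch of consuming such a batch, mirroring `_compute_batches` in the test (it assumes the `dataset` and `batch` from the sketch near the top of this page):

with tf.Session() as sess:
  try:
    while True:
      output = sess.run(batch)  # a tf.SparseTensorValue
      # dense_shape[0] is the batch size; dense_shape[1] is the longest
      # sequence in this batch. Shorter rows simply have fewer entries;
      # nothing is padded.
      print(output.dense_shape, output.values)
  except tf.errors.OutOfRangeError:
    pass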