author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-12 12:28:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-12 12:33:10 -0700
commit     f337425dc71e3ea95aa91ce401a40c1b594486ca (patch)
tree       48b112cbbe6e37ed14c04df3ab3d0c3460960f8c /tensorflow/contrib/data
parent     f02a2fad042bc401f3d1c89a9fd52e40ca5d1835 (diff)
Added the ability to bucket without padding (as sparse tensors) to `bucket_by_sequence_length`.
PiperOrigin-RevId: 212684420
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD              |   1
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py  | 174
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py                 |   9
3 files changed, 144 insertions, 40 deletions
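
For reference, a minimal usage sketch of the new flag (illustrative only: the toy generator, length function, and constants below are not part of this commit; the contrib-era TF 1.x API is assumed). With `no_padding=True` the elements must be `tf.SparseTensor`s, so each dense record is first converted with `tf.contrib.layers.dense_to_sparse`:

    import tensorflow as tf
    from tensorflow.contrib import layers
    from tensorflow.contrib.data.python.ops import grouping

    def gen():
      # Variable-length records of increasing length.
      for length in range(1, 20):
        yield ([1] * length,)

    dataset = tf.data.Dataset.from_generator(gen, (tf.int64,), ([None],))
    # Convert each dense record to a SparseTensor so batching needs no padding.
    dataset = dataset.map(lambda x: (layers.dense_to_sparse(x),))
    dataset = dataset.apply(
        grouping.bucket_by_sequence_length(
            lambda x: tf.shape(x)[0],      # length fn; works for dense and sparse
            bucket_boundaries=[5, 10],
            bucket_batch_sizes=[4, 4, 4],
            no_padding=True))              # batches come back as SparseTensors
    batch = dataset.make_one_shot_iterator().get_next()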
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b3c90ded39..1f947e97f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -44,6 +44,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:grouping",
+ "//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 293be2bd06..94718bb477 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -21,6 +21,7 @@ import random
import numpy as np
+from tensorflow.contrib import layers
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -531,6 +532,11 @@ class BucketTest(test.TestCase):
self.assertEqual(batches, 15)
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
class BucketBySequenceLength(test.TestCase):
def testBucket(self):
@@ -543,35 +549,49 @@ class BucketBySequenceLength(test.TestCase):
# Produce 1 batch for each bucket
elements = []
for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
for _ in range(batch_size):
- elements.append([1] * length)
+ elements.append([1] * record_len)
+ record_len = length
random.shuffle(elements)
for el in elements:
yield (el,)
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes))
- batch, = dataset.make_one_shot_iterator().get_next()
+ def _test_bucket_by_padding(no_padding):
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],))
+ if no_padding:
+ dataset = dataset.map(lambda x: (layers.dense_to_sparse(x),))
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
@@ -663,22 +683,100 @@ class BucketBySequenceLength(test.TestCase):
for x, y in zip(text, label):
yield (x, y)
- def element_length_fn(x, y):
- del y
- return array_ops.shape(x)[0]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator=elements_gen,
- output_shapes=(tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([])),
- output_types=(dtypes.int32, dtypes.int32))
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=elements_gen,
+ output_shapes=(tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([])),
+ output_types=(dtypes.int32, dtypes.int32))
+ if no_padding:
+ dataset = dataset.map(lambda x, y: (layers.dense_to_sparse(x), y))
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ The test runs on the following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield record
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_shapes=(tensor_shape.TensorShape([None])),
+ output_types=(dtypes.int64))
+ dataset = dataset.map(lambda x: layers.dense_to_sparse(x, eos_token=-1))
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8]))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6edc1d7990..099e10db92 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func,
bucket_batch_sizes,
padded_shapes=None,
padding_values=None,
- pad_to_bucket_boundary=False):
+ pad_to_bucket_boundary=False,
+ no_padding=False):
"""A transformation that buckets elements in a `Dataset` by length.
Elements of the `Dataset` are grouped together by length and then are padded
@@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func,
unknown size to bucket boundary minus 1 (i.e., the maximum length in each
bucket), and caller must ensure that the source `Dataset` does not contain
any elements with length longer than `max(bucket_boundaries)`.
+ no_padding: `bool`, if `True`, the batches are not padded; in that case all
+ features must either be of type `tf.SparseTensor` or have the same shape.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func,
def batching_fn(bucket_id, grouped_dataset):
"""Batch elements in dataset."""
- batch_size = batch_sizes[bucket_id]
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "