diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-17 16:31:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 16:35:28 -0700 |
commit | 8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (patch) | |
tree | 7123d7e44983f26da690ac511ceb09b77c067114 /tensorflow/python/data | |
parent | f5116dd366a5bb1d679e1682c13b8fa3c4830a84 (diff) |
[tf.data] Introducing `tf.data.Dataset.window(size, shift, stride, drop_remainder)`, which can be used for combining elements of the input dataset into "windows". A window
is itself a finite dataset and, among other things, can be used for generalized batching (see https://github.com/tensorflow/community/pull/5 for details).
PiperOrigin-RevId: 213360134
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/window_dataset_op_test.py | 295 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 93 |
3 files changed, 403 insertions, 2 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 631b87a718..17d4fec662 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -407,3 +407,20 @@ cuda_py_test( "//tensorflow/python:tensor_shape", ], ) + +tf_py_test( + name = "window_dataset_op_test", + size = "small", + srcs = ["window_dataset_op_test.py"], + additional_deps = [ + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + ], +) diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py new file mode 100644 index 0000000000..fd4348426d --- /dev/null +++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py @@ -0,0 +1,295 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class WindowDatasetTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), + ("12", 20, 14, 7, 1, False), + ("13", 20, 17, 9, 1, False), + ("14", 20, 14, 14, 1, False), + ("15", 20, 10, 14, 1, False), + ("16", 20, 14, 19, 1, False), + ("17", 20, 4, 1, 2, False), + ("18", 20, 2, 1, 6, False), + ("19", 20, 4, 7, 2, False), + ("20", 20, 2, 7, 6, False), + ("21", 1, 10, 4, 1, False), + ("22", 0, 10, 4, 1, False), + ) + def testWindowDataset(self, count, size, shift, stride, drop_remainder=True): + """Tests a dataset that slides a window its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + size_t = array_ops.placeholder(dtypes.int64, shape=[]) + shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), 
math_ops.square(z) + + def _flat_map_fn(x, y, z): + return dataset_ops.Dataset.zip((x.batch(batch_size=size_t), + y.batch(batch_size=size_t), + z.batch(batch_size=size_t))) + + iterator = dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn).repeat(count).window( + size=size_t, + shift=shift_t, + stride=stride_t, + drop_remainder=drop_remainder_t).flat_map( + _flat_map_fn).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.cached_session() as sess: + sess.run( + init_op, + feed_dict={ + count_t: count, + size_t: size, + shift_t: shift, + stride_t: stride, + drop_remainder_t: drop_remainder + }) + num_full_batches = max( + 0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1) + for i in range(num_full_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(size): + self.assertAllEqual(component[(i * shift + j * stride) % 7]**2, + result_component[j]) + if not drop_remainder: + num_partial_batches = (count * 7) // shift + ( + (count * 7) % shift > 0) - num_full_batches + for i in range(num_partial_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + remaining = (count * 7) - ((num_full_batches + i) * shift) + num_elements = remaining // stride + ((remaining % stride) > 0) + for j in range(num_elements): + self.assertAllEqual( + component[((num_full_batches + i) * shift + j * stride) % 7] + **2, result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", 14, 0, 3, 1), + ("2", 14, 3, 0, 1), + ("3", 14, 3, 3, 0), + ) + def testWindowDatasetInvalid(self, count, size, shift, stride): + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + size_t = array_ops.placeholder(dtypes.int64, shape=[]) + 
shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + + iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat( + count_t).window( + size=size_t, shift=shift_t, + stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t) + ).make_initializable_iterator() + init_op = iterator.initializer + + with self.cached_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run( + init_op, + feed_dict={ + count_t: count, + size_t: size, + shift_t: shift, + stride_t: stride + }) + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testWindowSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).window( + size=5, shift=3, drop_remainder=True).flat_map( + lambda x: x.batch(batch_size=5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testWindowSparseWithDifferentDenseShapes(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=array_ops.expand_dims( + math_ops.range(i, dtype=dtypes.int64), 1), + values=array_ops.fill([math_ops.to_int32(i)], i), + dense_shape=[i]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).window( + 
size=5, shift=3, drop_remainder=True).flat_map( + lambda x: x.batch(batch_size=5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected_indices = [] + expected_values = [] + for j in range(5): + for k in range(i * 3 + j): + expected_indices.append([j, k]) + expected_values.append(i * 3 + j) + expected = sparse_tensor.SparseTensorValue( + indices=expected_indices, + values=expected_values, + dense_shape=[5, i * 3 + 5 - 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedWindowSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).window( + size=4, shift=2, + drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window( + size=3, shift=1, drop_remainder=True).flat_map( + lambda x: x.batch(batch_size=3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + # Slide: 1st batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], + values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + # Slide: 2nd batch. 
+ actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], + values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testWindowShapeError(self): + + def generator(): + yield [1.0, 2.0, 3.0] + yield [4.0, 5.0, 6.0] + yield [7.0, 8.0, 9.0, 10.0] + + iterator = dataset_ops.Dataset.from_generator( + generator, dtypes.float32, output_shapes=[None]).window( + size=3, shift=1).flat_map( + lambda x: x.batch(batch_size=3)).make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot batch tensors with different shapes in component 0. 
" + r"First element had shape \[3\] and element 2 had shape \[4\]."): + sess.run(next_element) + + def testWindowIgnoreErrors(self): + input_values = np.float32([1., np.nan, 2., np.nan, 3.]) + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( + lambda x: array_ops.check_numerics(x, "message")).window( + size=2, shift=2, stride=2, + drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2)) + get_next = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next)) + self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c985e00dd1..93b3a7b93b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1115,7 +1115,7 @@ class Dataset(object): return FilterDataset(self, predicate) def apply(self, transformation_func): - """Apply a transformation function to this dataset. + """Applies a transformation function to this dataset. `apply` enables chaining of custom `Dataset` transformations, which are represented as functions that take one `Dataset` argument and return a @@ -1131,7 +1131,7 @@ class Dataset(object): Args: transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. + returns a `Dataset`. Returns: Dataset: The `Dataset` returned by applying `transformation_func` to this @@ -1142,6 +1142,45 @@ class Dataset(object): raise TypeError("`transformation_func` must return a Dataset.") return dataset + def window(self, size, shift=None, stride=1, drop_remainder=False): + """Combines input elements into a dataset of windows. 
+ + Each window is a dataset itself and contains `size` elements (or + possibly fewer if there are not enough input elements to fill the window + and `drop_remainder` evaluates to false). + + The `stride` argument determines the stride of the input elements, + and the `shift` argument determines the shift of the window. + + For example: + - `tf.data.Dataset.range(7).window(2)` produces + `{{0, 1}, {2, 3}, {4, 5}, {6}}` + - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces + `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}` + - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces + `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}` + + Args: + size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements + of the input dataset to combine into a window. + shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + forward shift of the sliding window in each iteration. Defaults to + `size`. + stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + stride of the input elements in the sliding window. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether a window should be dropped in case its size is smaller than + `size`. + + Returns: + Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with + the same structure as this dataset, but a finite subsequence of its + elements. + """ + if shift is None: + shift = size + return WindowDataset(self, size, shift, stride, drop_remainder) + + class TensorDataset(Dataset): + """A `Dataset` with a single element, viz. 
a nested structure of tensors.""" @@ -2442,3 +2481,53 @@ class PrefetchDataset(Dataset): @property def output_types(self): return self._input_dataset.output_types + + +class WindowDataset(Dataset): + """A dataset that creates window datasets from the input elements.""" + + def __init__(self, input_dataset, size, shift, stride, drop_remainder): + """See `window_dataset()` for more details.""" + super(WindowDataset, self).__init__() + self._input_dataset = input_dataset + self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") + self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") + self._stride = ops.convert_to_tensor( + stride, dtype=dtypes.int64, name="stride") + self._drop_remainder = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + self._output_classes = nest.pack_sequence_as( + input_dataset.output_classes, + [ + _NestedDatasetComponent( # pylint: disable=protected-access + output_classes=output_class, + output_shapes=output_shape, + output_types=output_type) + for output_class, output_shape, output_type in zip( + nest.flatten(input_dataset.output_classes), + nest.flatten(input_dataset.output_shapes), + nest.flatten(input_dataset.output_types)) + ]) + self._output_shapes = self._output_classes + self._output_types = self._output_classes + + def _as_variant_tensor(self): + return gen_dataset_ops.window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._size, + self._shift, + self._stride, + self._drop_remainder, + **flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types |