aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-17 16:31:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:35:28 -0700
commit8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (patch)
tree7123d7e44983f26da690ac511ceb09b77c067114 /tensorflow/contrib/data
parentf5116dd366a5bb1d679e1682c13b8fa3c4830a84 (diff)
[tf.data] Introducing `tf.data.Dataset.window(size, shift, stride, drop_remainder)`, which can be used for combining elements of the input dataset into "windows". A window
is itself a finite dataset and, among other things, can be used for generalized batching (see https://github.com/tensorflow/community/pull/5 for details).
PiperOrigin-RevId: 213360134
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py7
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py51
-rw-r--r--tensorflow/contrib/data/python/ops/sliding.py4
3 files changed, 15 insertions, 47 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b1959..8b7b3ac0f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
return dataset_ops.Dataset.zip(
tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
- dataset = self._structuredDataset(structure, shape, dtype).apply(
+ dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
+ for _ in range(5):
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
@parameterized.named_parameters(
("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 099e10db92..020167e4d1 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -255,6 +255,7 @@ def _map_x_dataset(map_func):
return _apply_fn
+# TODO(b/115382007) Remove this once canned reducers move to core.
def window_dataset(window_size):
"""A transformation that creates window datasets from the input dataset.
@@ -271,7 +272,12 @@ def window_dataset(window_size):
"""
def _apply_fn(dataset):
- return _WindowDataset(dataset, window_size)
+ return dataset_ops.WindowDataset(
+ dataset,
+ size=window_size,
+ shift=window_size,
+ stride=1,
+ drop_remainder=False)
return _apply_fn
@@ -556,46 +562,3 @@ class _MapXDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
- """A dataset that creates window datasets from the input elements."""
-
- def __init__(self, input_dataset, window_size):
- """See `window_dataset()` for more details."""
- super(_WindowDataset, self).__init__()
- self._input_dataset = input_dataset
- self._window_size = ops.convert_to_tensor(
- window_size, dtype=dtypes.int64, name="window_size")
- self._output_classes = nest.pack_sequence_as(
- input_dataset.output_classes,
- [
- dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
- output_classes=output_class,
- output_shapes=output_shape,
- output_types=output_type)
- for output_class, output_shape, output_type in zip(
- nest.flatten(input_dataset.output_classes),
- nest.flatten(input_dataset.output_shapes),
- nest.flatten(input_dataset.output_types))
- ])
- self._output_shapes = self._output_classes
- self._output_types = self._output_classes
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._window_size,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcdd16..b0d6a16c20 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset):
@deprecation.deprecated_args(
None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+ "stride=window_stride).flat_map(lambda x: x.batch(window_size))` "
+ "instead.")
def sliding_window_batch(window_size,
stride=None,
window_shift=None,