diff options
author | 2018-09-17 16:31:24 -0700 | |
---|---|---|
committer | 2018-09-17 16:35:28 -0700 | |
commit | 8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (patch) | |
tree | 7123d7e44983f26da690ac511ceb09b77c067114 /tensorflow/contrib/data | |
parent | f5116dd366a5bb1d679e1682c13b8fa3c4830a84 (diff) |
[tf.data] Introducing `tf.data.Dataset.window(size, shift, stride, drop_remainder)`, which can be used for combining elements of the input dataset into "windows". A window
is itself a finite dataset and, among other things, can be used for generalized batching (see https://github.com/tensorflow/community/pull/5 for details).
PiperOrigin-RevId: 213360134
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/grouping.py | 51 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/sliding.py | 4 |
3 files changed, 15 insertions, 47 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b1959..8b7b3ac0f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
       return dataset_ops.Dataset.zip(
           tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
 
-    dataset = self._structuredDataset(structure, shape, dtype).apply(
+    dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
         grouping.window_dataset(5)).flat_map(fn)
     get_next = dataset.make_one_shot_iterator().get_next()
     with self.cached_session() as sess:
       expected = sess.run(self._structuredElement(structure, shape, dtype))
-      actual = sess.run(get_next)
-      self._assertEqual(expected, actual)
+      for _ in range(5):
+        actual = sess.run(get_next)
+        self._assertEqual(expected, actual)
 
   @parameterized.named_parameters(
       ("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 099e10db92..020167e4d1 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -255,6 +255,7 @@ def _map_x_dataset(map_func):
   return _apply_fn
 
 
+# TODO(b/115382007) Remove this once canned reducers move to core.
 def window_dataset(window_size):
   """A transformation that creates window datasets from the input dataset.
 
@@ -271,7 +272,12 @@ def window_dataset(window_size):
   """
 
   def _apply_fn(dataset):
-    return _WindowDataset(dataset, window_size)
+    return dataset_ops.WindowDataset(
+        dataset,
+        size=window_size,
+        shift=window_size,
+        stride=1,
+        drop_remainder=False)
 
   return _apply_fn
 
@@ -556,46 +562,3 @@ class _MapXDataset(dataset_ops.Dataset):
   @property
   def output_types(self):
     return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
-  """A dataset that creates window datasets from the input elements."""
-
-  def __init__(self, input_dataset, window_size):
-    """See `window_dataset()` for more details."""
-    super(_WindowDataset, self).__init__()
-    self._input_dataset = input_dataset
-    self._window_size = ops.convert_to_tensor(
-        window_size, dtype=dtypes.int64, name="window_size")
-    self._output_classes = nest.pack_sequence_as(
-        input_dataset.output_classes,
-        [
-            dataset_ops._NestedDatasetComponent(  # pylint: disable=protected-access
-                output_classes=output_class,
-                output_shapes=output_shape,
-                output_types=output_type)
-            for output_class, output_shape, output_type in zip(
-                nest.flatten(input_dataset.output_classes),
-                nest.flatten(input_dataset.output_shapes),
-                nest.flatten(input_dataset.output_types))
-        ])
-    self._output_shapes = self._output_classes
-    self._output_types = self._output_classes
-
-  def _as_variant_tensor(self):
-    return gen_dataset_ops.window_dataset(
-        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
-        self._window_size,
-        **dataset_ops.flat_structure(self))
-
-  @property
-  def output_classes(self):
-    return self._output_classes
-
-  @property
-  def output_shapes(self):
-    return self._output_shapes
-
-  @property
-  def output_types(self):
-    return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcdd16..b0d6a16c20 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset):
 
 @deprecation.deprecated_args(
     None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+    None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+    "stride=window_stride).flat_map(lambda x: x.batch(window.size))` "
+    "instead.")
 def sliding_window_batch(window_size,
                          stride=None,
                          window_shift=None,