Diffstat (limited to 'tensorflow/contrib/data/python/ops/grouping.py')
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py  136
1 file changed, 131 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ca9540bf13..bd8d398c58 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -149,9 +149,9 @@ def bucket_by_sequence_length(element_length_func,
@{tf.data.Dataset.padded_batch}. Defaults to padding with 0.
pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
size to maximum length in batch. If `True`, will pad dimensions with
- unknown size to bucket boundary, and caller must ensure that the source
- `Dataset` does not contain any elements with length longer than
- `max(bucket_boundaries)`.
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements whose length equals or exceeds `max(bucket_boundaries)`.
Returns:
A `Dataset` transformation function, which can be passed to
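For context, a minimal usage sketch of the behavior documented above (not part of this commit; the toy data, boundaries, and batch sizes are illustrative, assuming the TF 1.x `tf.contrib.data.bucket_by_sequence_length` API):

    import tensorflow as tf

    elements = [[1, 2, 3], [1, 2, 3, 4, 5], [1]]
    dataset = tf.data.Dataset.from_generator(
        lambda: iter(elements), output_types=tf.int64, output_shapes=[None])
    dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[4, 8],      # buckets: [0, 4), [4, 8), [8, inf)
        bucket_batch_sizes=[2, 2, 2],
        pad_to_bucket_boundary=True))  # unknown dims padded to boundary - 1

Every example here has length below `max(bucket_boundaries) = 8`, so the length check discussed in the next hunks never fires.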
@@ -203,7 +203,7 @@ def bucket_by_sequence_length(element_length_func,
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length <= max(bucket_boundaries).")
+ "length < max(bucket_boundaries).")
check = check_ops.assert_less(
bucket_id,
constant_op.constant(len(bucket_batch_sizes) - 1,
@@ -213,7 +213,7 @@ def bucket_by_sequence_length(element_length_func,
boundaries = constant_op.constant(bucket_boundaries,
dtype=dtypes.int64)
bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary
+ none_filler = bucket_boundary - 1
shapes = make_padded_shapes(
padded_shapes or grouped_dataset.output_shapes,
none_filler=none_filler)
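To make the new filler value concrete, an illustrative calculation (the boundary values are hypothetical, not from this commit):

    bucket_boundaries = [10, 20]
    bucket_id = 1                                   # e.g. an element of length 12
    none_filler = bucket_boundaries[bucket_id] - 1  # 19, the longest length the [10, 20) bucket admits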
@@ -227,6 +227,50 @@ def bucket_by_sequence_length(element_length_func,
return _apply_fn
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
+
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
+
+ return _apply_fn
+
+
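A rough usage sketch (an assumption, since `_map_x_dataset` is module-private and not exported): for plain tensors it behaves like `Dataset.map`; its distinguishing feature is that `map_func` may also take or return datasets.

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import grouping

    dataset = tf.data.Dataset.range(5)
    dataset = dataset.apply(grouping._map_x_dataset(lambda x: x * 2))  # 0, 2, 4, 6, 8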
+def window_dataset(window_size):
+ """A transformation that creates window datasets from the input dataset.
+
+ The resulting datasets will contain `window_size` elements (or
+ `N % window_size` for the last dataset if `window_size` does not divide the
+ number of input elements `N` evenly).
+
+ Args:
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of the input dataset to combine into a window.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _WindowDataset(dataset, window_size)
+
+ return _apply_fn
+
+
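A minimal sketch of the windowing behavior described above (the range and window size are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import grouping

    dataset = tf.data.Dataset.range(7)
    dataset = dataset.apply(grouping.window_dataset(3))
    # yields three nested datasets with elements {0, 1, 2}, {3, 4, 5}, and {6}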
class _GroupByReducerDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a reduction."""
@@ -468,3 +512,85 @@ class Reducer(object):
@property
def finalize_func(self):
return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.Dataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.contrib.data.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+class _WindowDataset(dataset_ops.Dataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, window_size):
+ """See `window_dataset()` for more details."""
+ super(_WindowDataset, self).__init__()
+ self._input_dataset = input_dataset
+ self._window_size = ops.convert_to_tensor(
+ window_size, dtype=dtypes.int64, name="window_size")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ # Each _NestedDatasetComponent above already carries its component's
+ # classes, shapes, and types, so the same packed structure is reused for
+ # all three properties.
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._window_size,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
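Putting the two additions together, a hedged end-to-end sketch (the window size and the batching step are assumptions about typical use, not taken from this commit):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import grouping

    dataset = tf.data.Dataset.range(10)
    windows = dataset.apply(grouping.window_dataset(4))   # elements are nested datasets
    batched = windows.apply(
        grouping._map_x_dataset(lambda w: w.batch(4)))    # `w` is itself a dataset
    # `batched` still yields nested (now single-element) datasets; passing a
    # dataset-valued `w` is exactly what `_map_x_dataset` adds over plain
    # `Dataset.map`, which did not support dataset inputs at the time.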