Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py  201
1 file changed, 171 insertions(+), 30 deletions(-)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c985e00dd1..ac87a451b1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
"""
raise NotImplementedError("Dataset._as_variant_tensor")
+ @abc.abstractmethod
+ def _inputs(self):
+ """Returns a list of the input datasets of the dataset."""
+
+ raise NotImplementedError("Dataset._inputs")
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -1009,6 +1015,23 @@ class Dataset(object):
def flat_map(self, map_func):
"""Maps `map_func` across this dataset and flattens the result.
+ Use `flat_map` if you want to make sure that the order of your dataset
+ stays the same. For example, to flatten a dataset of batches into a
+ dataset of their elements:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset, and `[...]` to represent a tensor.
+ a = {[1,2,3,4,5], [6,7,8,9], [10]}
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+ {[1,2,3,4,5,6,7,8,9,10]}
+ ```
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
+ `tf.data.Dataset.interleave(cycle_length=1)`.
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
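A runnable variant of the docstring example above, assuming a rectangular input since `from_tensor_slices` requires uniformly shaped rows (TF 1.x graph-mode style):

```python
import tensorflow as tf

batched = tf.data.Dataset.from_tensor_slices(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Flatten the dataset of rows back into a dataset of scalars,
# preserving order: 1, 2, 3, ..., 9.
flat = batched.flat_map(tf.data.Dataset.from_tensor_slices)

next_element = flat.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  for _ in range(9):
    print(sess.run(next_element))
```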
@@ -1043,7 +1066,7 @@ class Dataset(object):
elements are produced. `cycle_length` controls the number of input elements
that are processed concurrently. If you set `cycle_length` to 1, this
transformation will handle one input element at a time, and will produce
- identical results = to `tf.data.Dataset.flat_map`. In general,
+ identical results to `tf.data.Dataset.flat_map`. In general,
this transformation will apply `map_func` to `cycle_length` input elements,
open iterators on the returned `Dataset` objects, and cycle through them
producing `block_length` consecutive elements from each iterator, and
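To illustrate the equivalence stated above, a small sketch comparing `flat_map` with `interleave` at different `cycle_length` settings:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)
repeat_twice = lambda x: tf.data.Dataset.from_tensors(x).repeat(2)

# These two pipelines produce identical output: 0, 0, 1, 1, 2, 2.
a = ds.flat_map(repeat_twice)
b = ds.interleave(repeat_twice, cycle_length=1)

# With cycle_length=2 and block_length=1, two input elements are
# processed concurrently and their outputs alternate: 0, 1, 0, 1, 2, 2.
c = ds.interleave(repeat_twice, cycle_length=2, block_length=1)
```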
@@ -1115,7 +1138,7 @@ class Dataset(object):
return FilterDataset(self, predicate)
def apply(self, transformation_func):
- """Apply a transformation function to this dataset.
+ """Applies a transformation function to this dataset.
`apply` enables chaining of custom `Dataset` transformations, which are
represented as functions that take one `Dataset` argument and return a
@@ -1131,7 +1154,7 @@ class Dataset(object):
Args:
transformation_func: A function that takes one `Dataset` argument and
- returns a `Dataset`.
+ returns a `Dataset`.
Returns:
Dataset: The `Dataset` returned by applying `transformation_func` to this
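For illustration, a custom transformation function chained with `apply` might look like the following sketch (`double_and_batch` is a hypothetical example, not an existing API):

```python
import tensorflow as tf

def double_and_batch(batch_size):
  """Returns a Dataset-to-Dataset transformation (hypothetical example)."""
  def _apply_fn(dataset):
    return dataset.map(lambda x: x * 2).batch(batch_size)
  return _apply_fn

ds = tf.data.Dataset.range(8).apply(double_and_batch(4))
# Produces: [0, 2, 4, 6], [8, 10, 12, 14]
```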
@@ -1140,10 +1163,68 @@ class Dataset(object):
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
raise TypeError("`transformation_func` must return a Dataset.")
+ dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
+ def window(self, size, shift=None, stride=1, drop_remainder=False):
+ """Combines input elements into a dataset of windows.
+
+ Each window is a dataset itself and contains `size` elements (or
+ possibly fewer if there are not enough input elements to fill the window
+ and `drop_remainder` evaluates to false).
+
+ The `stride` argument determines the spacing between the input elements
+ selected for each window, and the `shift` argument determines how many
+ input elements the window advances on each iteration.
+
+ For example:
+ - `tf.data.Dataset.range(7).window(2)` produces
+ `{{0, 1}, {2, 3}, {4, 5}, {6}}`
+ - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
+ `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
+ - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
+ `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
+
+ Args:
+ size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
+ of the input dataset to combine into a window.
+ shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ forward shift of the sliding window in each iteration. Defaults to
+ `size`.
+ stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ stride of the input elements in the sliding window.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether a window should be dropped in case its size is smaller than
+ `size`.
+
+ Returns:
+ Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
+ the same structure as this dataset, but a finite subsequence of its
+ elements.
+ """
+ if shift is None:
+ shift = size
+ return WindowDataset(self, size, shift, stride, drop_remainder)
+
+
+class DatasetSource(Dataset):
+ """Abstract class representing a dataset with no inputs."""
+
+ def _inputs(self):
+ return []
+
+
+class UnaryDataset(Dataset):
+ """Abstract class representing a dataset with one input."""
+
+ def __init__(self, input_dataset):
+ super(UnaryDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _inputs(self):
+ return [self._input_dataset]
+
-class TensorDataset(Dataset):
+class TensorDataset(DatasetSource):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
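Because each window is itself a nested `Dataset`, a typical consumer flattens the windows back out, e.g. via `flat_map` plus `batch`. A sketch reproducing the second docstring example above:

```python
import tensorflow as tf

windows = tf.data.Dataset.range(7).window(
    size=3, shift=2, stride=1, drop_remainder=True)
# Materialize each window as a dense batch of up to 3 elements.
batches = windows.flat_map(lambda window: window.batch(3))
# Produces: [0, 1, 2], [2, 3, 4], [4, 5, 6]
```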
@@ -1183,7 +1264,7 @@ class TensorDataset(Dataset):
return self._output_types
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
"""A `Dataset` of slices from a nested structure of tensors."""
def __init__(self, tensors):
@@ -1227,7 +1308,7 @@ class TensorSliceDataset(Dataset):
return self._output_types
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
"""A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
def __init__(self, sparse_tensor):
@@ -1328,6 +1409,9 @@ class _VariantDataset(Dataset):
def _as_variant_tensor(self):
return self._dataset_variant
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return self._structure.output_classes
@@ -1568,7 +1652,7 @@ def flat_structure(dataset):
}
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1669,6 +1753,9 @@ class ZipDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return nest.flatten(self._datasets)
+
@property
def output_classes(self):
return nest.pack_sequence_as(
@@ -1704,6 +1791,7 @@ class ConcatenateDataset(Dataset):
raise TypeError(
"Two datasets to concatenate have different classes %s and %s" %
(input_dataset.output_classes, dataset_to_concatenate.output_classes))
+ self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -1713,6 +1801,9 @@ class ConcatenateDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._input_dataset, self._dataset_to_concatenate]
+
@property
def output_classes(self):
return self._input_dataset.output_classes
@@ -1731,12 +1822,12 @@ class ConcatenateDataset(Dataset):
return self._input_dataset.output_types
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
"""A `Dataset` that repeats its input several times."""
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
- super(RepeatDataset, self).__init__()
+ super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1763,7 +1854,7 @@ class RepeatDataset(Dataset):
return self._input_dataset.output_types
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
@@ -1811,12 +1902,12 @@ class RangeDataset(Dataset):
return dtypes.int64
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
"""A `Dataset` that caches elements of its input."""
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
- super(CacheDataset, self).__init__()
+ super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
@@ -1840,7 +1931,7 @@ class CacheDataset(Dataset):
return self._input_dataset.output_types
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
def __init__(self,
@@ -1868,7 +1959,7 @@ class ShuffleDataset(Dataset):
Raises:
ValueError: if invalid arguments are provided.
"""
- super(ShuffleDataset, self).__init__()
+ super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1900,12 +1991,12 @@ class ShuffleDataset(Dataset):
return self._input_dataset.output_types
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
"""A `Dataset` containing the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
- super(TakeDataset, self).__init__()
+ super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1928,12 +2019,12 @@ class TakeDataset(Dataset):
return self._input_dataset.output_types
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
"""A `Dataset` skipping the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
- super(SkipDataset, self).__init__()
+ super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1956,12 +2047,12 @@ class SkipDataset(Dataset):
return self._input_dataset.output_types
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
- super(BatchDataset, self).__init__()
+ super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2110,13 +2201,13 @@ def _default_padding(input_dataset):
return nest.map_structure(make_zero, input_dataset.output_types)
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
- super(PaddedBatchDataset, self).__init__()
+ super(PaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@@ -2216,12 +2307,12 @@ def _warn_if_collections(transformation_name):
% transformation_name)
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(MapDataset, self).__init__()
+ super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@@ -2282,12 +2373,12 @@ class ParallelMapDataset(MapDataset):
# pylint: enable=protected-access
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
- super(FlatMapDataset, self).__init__()
+ super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
@@ -2378,12 +2469,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
"""A `Dataset` that filters its input according to a predicate function."""
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
- super(FilterDataset, self).__init__()
+ super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, "Dataset.filter()", input_dataset)
@@ -2413,12 +2504,12 @@ class FilterDataset(Dataset):
return self._input_dataset.output_types
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
"""A `Dataset` that asynchronously prefetches its input."""
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
- super(PrefetchDataset, self).__init__()
+ super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
@@ -2442,3 +2533,53 @@ class PrefetchDataset(Dataset):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
+class WindowDataset(UnaryDataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, size, shift, stride, drop_remainder):
+ """See `window_dataset()` for more details."""
+ super(WindowDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
+ self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
+ self._stride = ops.convert_to_tensor(
+ stride, dtype=dtypes.int64, name="stride")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ _NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._size,
+ self._shift,
+ self._stride,
+ self._drop_remainder,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
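
For reference, the nested `_NestedDatasetComponent` structure above is what lets downstream transformations receive each window component as a `Dataset`. With a structured (tuple) input, the function passed to `flat_map` receives one nested dataset per component, as in this sketch:

```python
import tensorflow as tf

pairs = tf.data.Dataset.from_tensor_slices((tf.range(6), tf.range(6) * 10))
windows = pairs.window(2)

# Each component of a window arrives as its own nested Dataset;
# zip the re-batched components back into tuples.
result = windows.flat_map(
    lambda x, y: tf.data.Dataset.zip((x.batch(2), y.batch(2))))
# Produces: ([0, 1], [0, 10]), ([2, 3], [20, 30]), ([4, 5], [40, 50])
```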