diff options
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 201 |
1 files changed, 171 insertions, 30 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c985e00dd1..ac87a451b1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -80,6 +80,12 @@ class Dataset(object): """ raise NotImplementedError("Dataset._as_variant_tensor") + @abc.abstractmethod + def _inputs(self): + """Returns a list of the input datasets of the dataset.""" + + raise NotImplementedError("Dataset._inputs") + def make_initializable_iterator(self, shared_name=None): """Creates an `Iterator` for enumerating the elements of this dataset. @@ -1009,6 +1015,23 @@ class Dataset(object): def flat_map(self, map_func): """Maps `map_func` across this dataset and flattens the result. + Use `flat_map` if you want to make sure that the order of your dataset + stays the same. For example, to flatten a dataset of batches into a + dataset of their elements: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. '[...]' represents a tensor. + a = {[1,2,3,4,5], [6,7,8,9], [10]} + + a.flat_map(lambda x: Dataset.from_tensor_slices(x)) == + {1,2,3,4,5,6,7,8,9,10} + ``` + + `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since + `flat_map` produces the same output as + `tf.data.Dataset.interleave(cycle_length=1)`. + Args: map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a @@ -1043,7 +1066,7 @@ class Dataset(object): elements are produced. `cycle_length` controls the number of input elements that are processed concurrently. If you set `cycle_length` to 1, this transformation will handle one input element at a time, and will produce - identical results = to `tf.data.Dataset.flat_map`. In general, + identical results to `tf.data.Dataset.flat_map`. 
In general, this transformation will apply `map_func` to `cycle_length` input elements, open iterators on the returned `Dataset` objects, and cycle through them producing `block_length` consecutive elements from each iterator, and @@ -1115,7 +1138,7 @@ class Dataset(object): return FilterDataset(self, predicate) def apply(self, transformation_func): - """Apply a transformation function to this dataset. + """Applies a transformation function to this dataset. `apply` enables chaining of custom `Dataset` transformations, which are represented as functions that take one `Dataset` argument and return a @@ -1131,7 +1154,7 @@ class Dataset(object): Args: transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. + returns a `Dataset`. Returns: Dataset: The `Dataset` returned by applying `transformation_func` to this @@ -1140,10 +1163,68 @@ class Dataset(object): dataset = transformation_func(self) if not isinstance(dataset, Dataset): raise TypeError("`transformation_func` must return a Dataset.") + dataset._input_datasets = [self] # pylint: disable=protected-access return dataset + def window(self, size, shift=None, stride=1, drop_remainder=False): + """Combines input elements into a dataset of windows. + + Each window is a dataset itself and contains `size` elements (or + possibly fewer if there are not enough input elements to fill the window + and `drop_remainder` evaluates to false). + + The `stride` argument determines the stride of the input elements, + and the `shift` argument determines the shift of the window. 
+ + For example: + - `tf.data.Dataset.range(7).window(2)` produces + `{{0, 1}, {2, 3}, {4, 5}, {6}}` + - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces + `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}` + - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces + `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}` + + Args: + size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements + of the input dataset to combine into a window. + shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + forward shift of the sliding window in each iteration. Defaults to + `size`. + stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + stride of the input elements in the sliding window. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether a window should be dropped in case its size is smaller than + `size`. + + Returns: + Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with + the same structure as this dataset, but a finite subsequence of its + elements. + """ + if shift is None: + shift = size + return WindowDataset(self, size, shift, stride, drop_remainder) + + +class DatasetSource(Dataset): + """Abstract class representing a dataset with no inputs.""" + + def _inputs(self): + return [] + + +class UnaryDataset(Dataset): + """Abstract class representing a dataset with one input.""" + + def __init__(self, input_dataset): + super(UnaryDataset, self).__init__() + self._input_dataset = input_dataset + + def _inputs(self): + return [self._input_dataset] + -class TensorDataset(Dataset): +class TensorDataset(DatasetSource): """A `Dataset` with a single element, viz. 
a nested structure of tensors.""" def __init__(self, tensors): @@ -1183,7 +1264,7 @@ class TensorDataset(Dataset): return self._output_types -class TensorSliceDataset(Dataset): +class TensorSliceDataset(DatasetSource): """A `Dataset` of slices from a nested structure of tensors.""" def __init__(self, tensors): @@ -1227,7 +1308,7 @@ class TensorSliceDataset(Dataset): return self._output_types -class SparseTensorSliceDataset(Dataset): +class SparseTensorSliceDataset(DatasetSource): """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows.""" def __init__(self, sparse_tensor): @@ -1328,6 +1409,9 @@ class _VariantDataset(Dataset): def _as_variant_tensor(self): return self._dataset_variant + def _inputs(self): + return [] + @property def output_classes(self): return self._structure.output_classes @@ -1568,7 +1652,7 @@ def flat_structure(dataset): } -class _GeneratorDataset(Dataset): +class _GeneratorDataset(DatasetSource): """A `Dataset` that generates elements by invoking a function.""" def __init__(self, init_args, init_func, next_func, finalize_func): @@ -1669,6 +1753,9 @@ class ZipDataset(Dataset): **flat_structure(self)) # pylint: enable=protected-access + def _inputs(self): + return nest.flatten(self._datasets) + @property def output_classes(self): return nest.pack_sequence_as( @@ -1704,6 +1791,7 @@ class ConcatenateDataset(Dataset): raise TypeError( "Two datasets to concatenate have different classes %s and %s" % (input_dataset.output_classes, dataset_to_concatenate.output_classes)) + self._input_datasets = [input_dataset, dataset_to_concatenate] def _as_variant_tensor(self): # pylint: disable=protected-access @@ -1713,6 +1801,9 @@ class ConcatenateDataset(Dataset): **flat_structure(self)) # pylint: enable=protected-access + def _inputs(self): + return [self._input_dataset, self._dataset_to_concatenate] + @property def output_classes(self): return self._input_dataset.output_classes @@ -1731,12 +1822,12 @@ class ConcatenateDataset(Dataset): return 
self._input_dataset.output_types -class RepeatDataset(Dataset): +class RepeatDataset(UnaryDataset): """A `Dataset` that repeats its input several times.""" def __init__(self, input_dataset, count): """See `Dataset.repeat()` for details.""" - super(RepeatDataset, self).__init__() + super(RepeatDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if count is None: self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") @@ -1763,7 +1854,7 @@ class RepeatDataset(Dataset): return self._input_dataset.output_types -class RangeDataset(Dataset): +class RangeDataset(DatasetSource): """A `Dataset` of a step separated range of values.""" def __init__(self, *args): @@ -1811,12 +1902,12 @@ class RangeDataset(Dataset): return dtypes.int64 -class CacheDataset(Dataset): +class CacheDataset(UnaryDataset): """A `Dataset` that caches elements of its input.""" def __init__(self, input_dataset, filename): """See `Dataset.cache()` for details.""" - super(CacheDataset, self).__init__() + super(CacheDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._filename = ops.convert_to_tensor( filename, dtype=dtypes.string, name="filename") @@ -1840,7 +1931,7 @@ class CacheDataset(Dataset): return self._input_dataset.output_types -class ShuffleDataset(Dataset): +class ShuffleDataset(UnaryDataset): """A `Dataset` that randomly shuffles the elements of its input.""" def __init__(self, @@ -1868,7 +1959,7 @@ class ShuffleDataset(Dataset): Raises: ValueError: if invalid arguments are provided. 
""" - super(ShuffleDataset, self).__init__() + super(ShuffleDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") @@ -1900,12 +1991,12 @@ class ShuffleDataset(Dataset): return self._input_dataset.output_types -class TakeDataset(Dataset): +class TakeDataset(UnaryDataset): """A `Dataset` containing the first `count` elements from its input.""" def __init__(self, input_dataset, count): """See `Dataset.take()` for details.""" - super(TakeDataset, self).__init__() + super(TakeDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") @@ -1928,12 +2019,12 @@ class TakeDataset(Dataset): return self._input_dataset.output_types -class SkipDataset(Dataset): +class SkipDataset(UnaryDataset): """A `Dataset` skipping the first `count` elements from its input.""" def __init__(self, input_dataset, count): """See `Dataset.skip()` for details.""" - super(SkipDataset, self).__init__() + super(SkipDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") @@ -1956,12 +2047,12 @@ class SkipDataset(Dataset): return self._input_dataset.output_types -class BatchDataset(Dataset): +class BatchDataset(UnaryDataset): """A `Dataset` that batches contiguous elements from its input.""" def __init__(self, input_dataset, batch_size, drop_remainder): """See `Dataset.batch()` for details.""" - super(BatchDataset, self).__init__() + super(BatchDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") @@ -2110,13 +2201,13 @@ def _default_padding(input_dataset): return nest.map_structure(make_zero, input_dataset.output_types) -class PaddedBatchDataset(Dataset): +class 
PaddedBatchDataset(UnaryDataset): """A `Dataset` that batches and pads contiguous elements from its input.""" def __init__(self, input_dataset, batch_size, padded_shapes, padding_values, drop_remainder): """See `Dataset.batch()` for details.""" - super(PaddedBatchDataset, self).__init__() + super(PaddedBatchDataset, self).__init__(input_dataset) if sparse.any_sparse(input_dataset.output_classes): # TODO(b/63669786): support batching of sparse tensors raise TypeError( @@ -2216,12 +2307,12 @@ def _warn_if_collections(transformation_name): % transformation_name) -class MapDataset(Dataset): +class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" - super(MapDataset, self).__init__() + super(MapDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._use_inter_op_parallelism = use_inter_op_parallelism @@ -2282,12 +2373,12 @@ class ParallelMapDataset(MapDataset): # pylint: enable=protected-access -class FlatMapDataset(Dataset): +class FlatMapDataset(UnaryDataset): """A `Dataset` that maps a function over its input and flattens the result.""" def __init__(self, input_dataset, map_func): """See `Dataset.flat_map()` for details.""" - super(FlatMapDataset, self).__init__() + super(FlatMapDataset, self).__init__(input_dataset) self._input_dataset = input_dataset wrapped_func = StructuredFunctionWrapper( @@ -2378,12 +2469,12 @@ class ParallelInterleaveDataset(FlatMapDataset): return "Dataset.interleave()" -class FilterDataset(Dataset): +class FilterDataset(UnaryDataset): """A `Dataset` that filters its input according to a predicate function.""" def __init__(self, input_dataset, predicate): """See `Dataset.filter()` for details.""" - super(FilterDataset, self).__init__() + super(FilterDataset, self).__init__(input_dataset) self._input_dataset = input_dataset wrapped_func = 
StructuredFunctionWrapper( predicate, "Dataset.filter()", input_dataset) @@ -2413,12 +2504,12 @@ class FilterDataset(Dataset): return self._input_dataset.output_types -class PrefetchDataset(Dataset): +class PrefetchDataset(UnaryDataset): """A `Dataset` that asynchronously prefetches its input.""" def __init__(self, input_dataset, buffer_size): """See `Dataset.prefetch()` for details.""" - super(PrefetchDataset, self).__init__() + super(PrefetchDataset, self).__init__(input_dataset) self._input_dataset = input_dataset if buffer_size is None: buffer_size = -1 # This is the sentinel for auto-tuning. @@ -2442,3 +2533,53 @@ class PrefetchDataset(Dataset): @property def output_types(self): return self._input_dataset.output_types + + +class WindowDataset(UnaryDataset): + """A dataset that creates window datasets from the input elements.""" + + def __init__(self, input_dataset, size, shift, stride, drop_remainder): + """See `Dataset.window()` for details.""" + super(WindowDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size") + self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift") + self._stride = ops.convert_to_tensor( + stride, dtype=dtypes.int64, name="stride") + self._drop_remainder = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + self._output_classes = nest.pack_sequence_as( + input_dataset.output_classes, + [ + _NestedDatasetComponent( # pylint: disable=protected-access + output_classes=output_class, + output_shapes=output_shape, + output_types=output_type) + for output_class, output_shape, output_type in zip( + nest.flatten(input_dataset.output_classes), + nest.flatten(input_dataset.output_shapes), + nest.flatten(input_dataset.output_types)) + ]) + self._output_shapes = self._output_classes + self._output_types = self._output_classes + + def _as_variant_tensor(self): + return 
gen_dataset_ops.window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._size, + self._shift, + self._stride, + self._drop_remainder, + **flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types |