Diffstat (limited to 'tensorflow/contrib/data/python/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/contrib/data/python/ops/dataset_ops.py | 691
1 file changed, 691 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
new file mode 100644
index 0000000000..bb6b049694
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -0,0 +1,691 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_io_ops
+from tensorflow.python.util import deprecation
+
+
+class Dataset(dataset_ops.Dataset):
+  """Represents a potentially large set of elements.
+
+  A `Dataset` can be used to represent an input pipeline as a
+  collection of elements (nested structures of tensors) and a "logical
+  plan" of transformations that act on those elements.
+  """
+
+  def __init__(self, dataset):
+    super(Dataset, self).__init__()
+    self._dataset = dataset
+
+  @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.")
+  def make_dataset_resource(self):
+    return self._as_variant_tensor()
+
+  def _as_variant_tensor(self):
+    return self._dataset._as_variant_tensor()  # pylint: disable=protected-access
+
+  @property
+  def output_classes(self):
+    return self._dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return self._dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._dataset.output_types
+
+  @staticmethod
+  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.")
+  def from_tensors(tensors):
+    """Creates a `Dataset` with a single element, comprising the given tensors.
+
+    Args:
+      tensors: A nested structure of tensors.
+
+    Returns:
+      A `Dataset`.
+    """
+    return Dataset(dataset_ops.TensorDataset(tensors))
+
+  @staticmethod
+  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
+  def from_tensor_slices(tensors):
+    """Creates a `Dataset` whose elements are slices of the given tensors.
+
+    Args:
+      tensors: A nested structure of tensors, each having the same size in the
+        0th dimension.
+
+    Returns:
+      A `Dataset`.
+    """
+    return Dataset(dataset_ops.TensorSliceDataset(tensors))
+
+  @staticmethod
+  @deprecation.deprecated(None,
+                          "Use `tf.data.Dataset.from_sparse_tensor_slices()`.")
+  def from_sparse_tensor_slices(sparse_tensor):
+    """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
+
+    Args:
+      sparse_tensor: A `tf.SparseTensor`.
+
+    Returns:
+      A `Dataset` of rank-(N-1) sparse tensors.
+ """ + return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) + + @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") + def from_generator(generator, output_types, output_shapes=None): + """Creates a `Dataset` whose elements are generated by `generator`. + + The `generator` argument must be a callable object that returns + an object that support the `iter()` protocol (e.g. a generator function). + The elements generated by `generator` must be compatible with the given + `output_types` and (optional) `output_shapes` arguments. + + For example: + + ```python + import itertools + + def gen(): + for i in itertools.count(1): + yield (i, [1] * i) + + ds = Dataset.from_generator( + gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) + value = ds.make_one_shot_iterator().get_next() + + sess.run(value) # (1, array([1])) + sess.run(value) # (2, array([1, 1])) + ``` + + Args: + generator: A callable object that takes no arguments and returns an + object that supports the `iter()` protocol. + output_types: A nested structure of `tf.DType` objects corresponding to + each component of an element yielded by `generator`. + output_shapes: (Optional.) A nested structure of `tf.TensorShape` + objects corresponding to each component of an element yielded by + `generator`. + + Returns: + A `Dataset`. + """ + return Dataset( + dataset_ops.Dataset.from_generator(generator, output_types, + output_shapes)) + + @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") + def range(*args): + """Creates a `Dataset` of a step-separated range of values. + + For example: + + ```python + Dataset.range(5) == [0, 1, 2, 3, 4] + Dataset.range(2, 5) == [2, 3, 4] + Dataset.range(1, 5, 2) == [1, 3] + Dataset.range(1, 5, -2) == [] + Dataset.range(5, 1) == [] + Dataset.range(5, 1, -2) == [5, 3] + ``` + + Args: + *args: follow same semantics as python's xrange. + len(args) == 1 -> start = 0, stop = args[0], step = 1 + len(args) == 2 -> start = args[0], stop = args[1], step = 1 + len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] + + Returns: + A `RangeDataset`. + + Raises: + ValueError: if len(args) == 0. + """ + return Dataset(dataset_ops.RangeDataset(*args)) + + @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") + def zip(datasets): + """Creates a `Dataset` by zipping together the given datasets. + + This method has similar semantics to the built-in `zip()` function + in Python, with the main difference being that the `datasets` + argument can be an arbitrary nested structure of `Dataset` objects. + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3 } + b = { 4, 5, 6 } + c = { (7, 8), (9, 10), (11, 12) } + d = { 13, 14 } + + # The nested structure of the `datasets` argument determines the + # structure of elements in the resulting dataset. + Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) } + Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) } + + # The `datasets` argument may contain an arbitrary number of + # datasets. + Dataset.zip((a, b, c)) == { (1, 4, (7, 8)), + (2, 5, (9, 10)), + (3, 6, (11, 12)) } + + # The number of elements in the resulting dataset is the same as + # the size of the smallest dataset in `datasets`. + Dataset.zip((a, d)) == { (1, 13), (2, 14) } + ``` + + Args: + datasets: A nested structure of datasets. + + Returns: + A `Dataset`. 
+ """ + return Dataset(dataset_ops.ZipDataset(datasets)) + + def concatenate(self, dataset): + """Creates a `Dataset` by concatenating given dataset with this dataset. + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3 } + b = { 4, 5, 6, 7 } + + # Input dataset and dataset to be concatenated should have same + # nested structures and output types. + # c = { (8, 9), (10, 11), (12, 13) } + # d = { 14.0, 15.0, 16.0 } + # a.concatenate(c) and a.concatenate(d) would result in error. + + a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } + ``` + + Args: + dataset: `Dataset` to be concatenated. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset)) + + def prefetch(self, buffer_size): + """Creates a `Dataset` that prefetches elements from this dataset. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) + + @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") + def list_files(file_pattern): + """A dataset of all files matching a pattern. + + Example: + If we had the following files on our filesystem: + - /path/to/dir/a.txt + - /path/to/dir/b.py + - /path/to/dir/c.py + If we pass "/path/to/dir/*.py" as the directory, the dataset would + produce: + - /path/to/dir/b.py + - /path/to/dir/c.py + + Args: + file_pattern: A string or scalar string `tf.Tensor`, representing + the filename pattern that will be matched. + + Returns: + A `Dataset` of strings corresponding to file names. + """ + return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) + + def repeat(self, count=None): + """Repeats this dataset `count` times. + + Args: + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the elements of this dataset should be repeated. The + default behavior (if `count` is `None` or `-1`) is for the elements to + be repeated indefinitely. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) + + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") + def enumerate(self, start=0): + """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" + + return self.apply(enumerate_ops.enumerate_dataset(start)) + + def shuffle(self, buffer_size, seed=None): + """Randomly shuffles the elements of this dataset. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + number of elements from this dataset from which the new + dataset will sample. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed)) + + def cache(self, filename=""): + """Caches the elements in this dataset. + + Args: + filename: A `tf.string` scalar `tf.Tensor`, representing the name of a + directory on the filesystem to use for caching tensors in this Dataset. + If a filename is not provided, the dataset will be cached in memory. + + Returns: + A `Dataset`. 
+ """ + return Dataset(dataset_ops.CacheDataset(self._dataset, filename)) + + def take(self, count): + """Creates a `Dataset` with at most `count` elements from this dataset. + + Args: + count: A `tf.int64` scalar `tf.Tensor`, representing the number of + elements of this dataset that should be taken to form the new dataset. + If `count` is -1, or if `count` is greater than the size of this + dataset, the new dataset will contain all elements of this dataset. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.TakeDataset(self._dataset, count)) + + def skip(self, count): + """Creates a `Dataset` that skips `count` elements from this dataset. + + Args: + count: A `tf.int64` scalar `tf.Tensor`, representing the number + of elements of this dataset that should be skipped to form the + new dataset. If `count` is greater than the size of this + dataset, the new dataset will contain no elements. If `count` + is -1, skips the entire dataset. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.SkipDataset(self._dataset, count)) + + def shard(self, num_shards, index): + """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. + + This dataset operator is very useful when running distributed training, as + it allows each worker to read a unique subset. + + When reading a single input file, you can skip elements as follows: + + ```python + d = tf.data.TFRecordDataset(FLAGS.input_file) + d = d.shard(FLAGS.num_workers, FLAGS.worker_index) + d = d.repeat(FLAGS.num_epochs) + d = d.shuffle(FLAGS.shuffle_buffer_size) + d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) + ``` + + Important caveats: + + - Be sure to shard before you use any randomizing operator (such as + shuffle). + - Generally it is best if the shard operator is used early in the dataset + pipeline. For example, when reading from a set of TFRecord files, shard + before converting the dataset to input samples. This avoids reading every + file on every worker. The following is an example of an efficient + sharding strategy within a complete pipeline: + + ```python + d = tf.data.Dataset.list_files(FLAGS.pattern) + d = d.shard(FLAGS.num_workers, FLAGS.worker_index) + d = d.repeat(FLAGS.num_epochs) + d = d.shuffle(FLAGS.shuffle_buffer_size) + d = d.interleave(tf.data.TFRecordDataset, + cycle_length=FLAGS.num_readers, block_length=1) + d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) + ``` + + Args: + num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of + shards operating in parallel. + index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. + + Returns: + A `Dataset`. + + Raises: + ValueError: if `num_shards` or `index` are illegal values. Note: error + checking is done on a best-effort basis, and aren't guaranteed to be + caught upon dataset creation. (e.g. providing in a placeholder tensor + bypasses the early checking, and will instead result in an error during + a session.run call.) + """ + return Dataset(self._dataset.shard(num_shards, index)) + + @deprecation.deprecated(None, + "Use `ds.apply(tf.contrib.data.ignore_errors())`.") + def ignore_errors(self): + """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" + + return self.apply(error_ops.ignore_errors()) + + def batch(self, batch_size): + """Combines consecutive elements of this dataset into batches. + + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. 
+
+    Returns:
+      A `Dataset`.
+    """
+    return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size))
+
+  def padded_batch(self, batch_size, padded_shapes, padding_values=None):
+    """Combines consecutive elements of this dataset into padded batches.
+
+    Like `Dataset.dense_to_sparse_batch()`, this method combines
+    multiple consecutive elements of this dataset, which might have
+    different shapes, into a single element. The tensors in the
+    resulting element have an additional outer dimension, and are
+    padded to the respective shape in `padded_shapes`.
+
+    Args:
+      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+        consecutive elements of this dataset to combine in a single batch.
+      padded_shapes: A nested structure of `tf.TensorShape` or
+        `tf.int64` vector tensor-like objects representing the shape
+        to which the respective component of each input element should
+        be padded prior to batching. Any unknown dimensions
+        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
+        tensor-like object) will be padded to the maximum size of that
+        dimension in each batch.
+      padding_values: (Optional.) A nested structure of scalar-shaped
+        `tf.Tensor`, representing the padding values to use for the
+        respective components. Defaults are `0` for numeric types and
+        the empty string for string types.
+
+    Returns:
+      A `Dataset`.
+    """
+    return Dataset(
+        dataset_ops.PaddedBatchDataset(self._dataset, batch_size,
+                                       padded_shapes, padding_values))
+
+  @deprecation.deprecated(
+      None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.")
+  def dense_to_sparse_batch(self, batch_size, row_shape):
+    """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`."""
+
+    return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape))
+
+  @deprecation.deprecated(None,
+                          "Use `ds.apply(tf.contrib.data.group_by_window())`.")
+  def group_by_window(self, key_func, reduce_func, window_size):
+    """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`."""
+
+    return self.apply(
+        grouping.group_by_window(key_func, reduce_func, window_size))
+
+  @deprecation.deprecated_args(
+      None, "Replace `num_threads=T` with `num_parallel_calls=T`. Replace "
+      "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.",
+      "num_threads", "output_buffer_size")
+  def map(self,
+          map_func,
+          num_threads=None,
+          output_buffer_size=None,
+          num_parallel_calls=None):
+    """Maps `map_func` across this dataset.
+
+    Args:
+      map_func: A function mapping a nested structure of tensors (having
+        shapes and types defined by `self.output_shapes` and
+        `self.output_types`) to another nested structure of tensors.
+      num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead.
+      output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+        representing the maximum number of processed elements that will be
+        buffered.
+      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+        representing the number of elements to process in parallel. If not
+        specified, elements will be processed sequentially.
+
+    Returns:
+      A `Dataset`.
+ """ + if num_threads is None and num_parallel_calls is None: + ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func)) + else: + if num_threads is None: + ret = Dataset( + dataset_ops.ParallelMapDataset(self._dataset, map_func, + num_parallel_calls)) + else: + ret = Dataset( + dataset_ops.ParallelMapDataset(self._dataset, map_func, + num_threads)) + if output_buffer_size is not None: + ret = ret.prefetch(output_buffer_size) + return ret + + def flat_map(self, map_func): + """Maps `map_func` across this dataset and flattens the result. + + Args: + map_func: A function mapping a nested structure of tensors (having shapes + and types defined by `self.output_shapes` and `self.output_types`) to a + `Dataset`. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func)) + + def interleave(self, map_func, cycle_length, block_length=1): + """Maps `map_func` across this dataset, and interleaves the results. + + For example, you can use `Dataset.interleave()` to process many input files + concurrently: + + ```python + # Preprocess 4 files concurrently, and interleave blocks of 16 records from + # each file. + filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] + dataset = (Dataset.from_tensor_slices(filenames) + .interleave(lambda x: + TextLineDataset(x).map(parse_fn, num_parallel_calls=1), + cycle_length=4, block_length=16)) + ``` + + The `cycle_length` and `block_length` arguments control the order in which + elements are produced. `cycle_length` controls the number of input elements + that are processed concurrently. If you set `cycle_length` to 1, this + transformation will handle one input element at a time, and will produce + identical results = to @{tf.data.Dataset.flat_map}. In general, + this transformation will apply `map_func` to `cycle_length` input elements, + open iterators on the returned `Dataset` objects, and cycle through them + producing `block_length` consecutive elements from each iterator, and + consuming the next input element each time it reaches the end of an + iterator. + + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3, 4, 5 } + + # NOTE: New lines indicate "block" boundaries. + a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), + cycle_length=2, block_length=4) == { + 1, 1, 1, 1, + 2, 2, 2, 2, + 1, 1, + 2, 2, + 3, 3, 3, 3, + 4, 4, 4, 4, + 3, 3, + 4, 4, + 5, 5, 5, 5, + 5, 5, + } + ``` + + NOTE: The order of elements yielded by this transformation is + deterministic, as long as `map_func` is a pure function. If + `map_func` contains any stateful operations, the order in which + that state is accessed is undefined. + + Args: + map_func: A function mapping a nested structure of tensors (having shapes + and types defined by `self.output_shapes` and `self.output_types`) to a + `Dataset`. + cycle_length: The number of elements from this dataset that will be + processed concurrently. + block_length: The number of consecutive elements to produce from each + input element before cycling to another input element. + + Returns: + A `Dataset`. 
+ """ + return Dataset( + dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, + block_length)) + + @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") + def unbatch(self): + """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" + + return self.apply(batching.unbatch()) + + def filter(self, predicate): + """Filters this dataset according to `predicate`. + + Args: + predicate: A function mapping a nested structure of tensors (having shapes + and types defined by `self.output_shapes` and `self.output_types`) to a + scalar `tf.bool` tensor. + + Returns: + A `Dataset`. + """ + return Dataset(dataset_ops.FilterDataset(self._dataset, predicate)) + + def apply(self, transformation_func): + """Apply a transformation function to this dataset. + + `apply` enables chaining of custom `Dataset` transformations, which are + represented as functions that take one `Dataset` argument and return a + transformed `Dataset`. + + For example: + + ``` + dataset = (dataset.map(lambda x: x ** 2) + .(group_by_window(key_func, reduce_func, window_size)) + .map(lambda x: x ** 3)) + ``` + + Args: + transformation_func: A function that takes one `Dataset` argument and + returns a `Dataset`. + + Returns: + The `Dataset` returned by applying `transformation_func` to this dataset. + """ + dataset = transformation_func(self) + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`transformation_func` must return a Dataset.") + return Dataset(dataset) + + +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... + return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + ``` + + Args: + dataset: A @{tf.data.Dataset} object containing a single element. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) |