From 348478f642216cf3cbe1eb67b875252d8e6a6418 Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Tue, 25 Sep 2018 13:42:46 -0700
Subject: [tf.data] Adding a private method for (recursively) tracking dataset
 inputs.

PiperOrigin-RevId: 214495925
---
 tensorflow/python/data/kernel_tests/BUILD          |  13 ++
 tensorflow/python/data/kernel_tests/inputs_test.py | 148 +++++++++++++++++++++
 tensorflow/python/data/ops/dataset_ops.py          | 109 +++++++++------
 .../python/data/ops/multi_device_iterator_ops.py   |   4 +
 tensorflow/python/data/ops/readers.py              |  12 ++
 5 files changed, 249 insertions(+), 37 deletions(-)
 create mode 100644 tensorflow/python/data/kernel_tests/inputs_test.py
(limited to 'tensorflow/python/data')

diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 28ee3ebaa6..7a6f03d4d3 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -445,3 +445,16 @@ tf_py_test(
         "//tensorflow/python/data/ops:dataset_ops",
     ],
 )
+
+tf_py_test(
+    name = "inputs_test",
+    size = "small",
+    srcs = ["inputs_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
new file mode 100644
index 0000000000..4c9279dd95
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class InputsTest(test.TestCase, parameterized.TestCase):
+
+  @staticmethod
+  def make_apply_fn(dataset):
+
+    def apply_fn(dataset):
+
+      def _apply_fn(dataset):
+        return dataset.cache()
+
+      return dataset.apply(_apply_fn)
+
+    return apply_fn
+
+  @staticmethod
+  def make_gen():
+
+    def gen():
+      yield 42
+
+    return gen
+
+  @staticmethod
+  def make_interleave_fn(dataset, num_parallel_calls=None):
+
+    def interleave_fn(dataset):
+      return dataset.interleave(
+          lambda x: dataset_ops.Dataset.range(0),
+          cycle_length=2,
+          num_parallel_calls=num_parallel_calls)
+
+    return interleave_fn
+
+  @parameterized.named_parameters(
+      ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+      ("FromGenerator",
+       dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+       1),
+      ("FromSparseTensorSlices",
+       dataset_ops.Dataset.from_sparse_tensor_slices(
+           sparse_tensor.SparseTensor(
+               indices=np.array([[0, 0], [1, 0], [2, 0]]),
+               values=np.array([0, 0, 0]),
+               dense_shape=np.array([3, 1])))),
+      ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+      ("FromTensorSlices", dataset_ops.Dataset.from_tensor_slices([42])),
+      ("Range", dataset_ops.Dataset.range(10)),
+      ("TextLine", readers.TextLineDataset("")),
+      ("TFRecord", readers.TFRecordDataset(""), 1),
+  )
+  def testDatasetSourceInputs(self, dataset, num_inputs=0):
+    self.assertEqual(num_inputs, len(dataset._inputs()))
+
+  @parameterized.named_parameters(
+      ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+       dataset_ops.Dataset.range(0)),
+      ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+      ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+      ("Filter", lambda x: x.filter(lambda x: True),
+       dataset_ops.Dataset.range(0)),
+      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+       dataset_ops.Dataset.range(0)),
+      ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+       dataset_ops.Dataset.range(0)),
+      ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+      ("PaddedBatch", lambda x: x.padded_batch(10, []),
+       dataset_ops.Dataset.range(0)),
+      ("ParallelInterleave",
+       make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+       dataset_ops.Dataset.range(0)),
+      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+       dataset_ops.Dataset.range(0)),
+      ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+      ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+      ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+      ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+      ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+  )
+  def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+  @parameterized.named_parameters(
+      ("Concatenate", lambda x, y: x.concatenate(y),
+       dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+  def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+  @parameterized.named_parameters(
+      ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+      ("ZipNest", dataset_ops.Dataset.zip,
+       (dataset_ops.Dataset.range(0),
+        (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+      ("ZipTuple", dataset_ops.Dataset.zip,
+       (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+  def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+    self.assertEqual(
+        nest.flatten(input_datasets),
+        dataset_fn(input_datasets)._inputs())
+
+  def testCollectInputs(self):
+    ds1 = dataset_ops.Dataset.range(0)
+    ds2 = ds1.concatenate(ds1)
+    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+    inputs = []
+    queue = [ds3]
+    while queue:
+      ds = queue[0]
+      queue = queue[1:]
+      queue.extend(ds._inputs())
+      inputs.append(ds)
+
+    self.assertEqual(5, inputs.count(ds1))
+    self.assertEqual(2, inputs.count(ds2))
+    self.assertEqual(1, inputs.count(ds3))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7c20c049f5..ac87a451b1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
     """
     raise NotImplementedError("Dataset._as_variant_tensor")
 
+  @abc.abstractmethod
+  def _inputs(self):
+    """Returns a list of the input datasets of the dataset."""
+
+    raise NotImplementedError("Dataset._inputs")
+
   def make_initializable_iterator(self, shared_name=None):
     """Creates an `Iterator` for enumerating the elements of this dataset.
 
@@ -1007,8 +1013,8 @@ class Dataset(object):
     return ParallelMapDataset(self, map_func, num_parallel_calls)
 
   def flat_map(self, map_func):
-    """Maps `map_func` across this dataset and flattens the result. 
-    
+    """Maps `map_func` across this dataset and flattens the result.
+
     Use `flat_map` if you want to make sure that the order of your dataset
     stays the same. For example, to flatten a dataset of batches into a
     dataset of their elements:
@@ -1017,15 +1023,15 @@ class Dataset(object):
     ```python
     # NOTE: The following examples use `{ ... }` to represent the
     # contents of a dataset. '[...]' represents a tensor.
     a = {[1,2,3,4,5], [6,7,8,9], [10]}
-    
-    a.flat_map(lambda x: Dataset.from_tensor_slices(x)) == 
+
+    a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
       {[1,2,3,4,5,6,7,8,9,10]}
     ```
-    
-    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since 
-    `flat_map` produces the same output as 
+
+    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+    `flat_map` produces the same output as
     `tf.data.Dataset.interleave(cycle_length=1)`
-    
+
     Args:
       map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to a
@@ -1157,6 +1163,7 @@ class Dataset(object):
     dataset = transformation_func(self)
     if not isinstance(dataset, Dataset):
       raise TypeError("`transformation_func` must return a Dataset.")
+    dataset._input_datasets = [self]  # pylint: disable=protected-access
     return dataset
 
   def window(self, size, shift=None, stride=1, drop_remainder=False):
@@ -1199,7 +1206,25 @@ class Dataset(object):
     return WindowDataset(self, size, shift, stride, drop_remainder)
 
 
-class TensorDataset(Dataset):
+class DatasetSource(Dataset):
+  """Abstract class representing a dataset with no inputs."""
+
+  def _inputs(self):
+    return []
+
+
+class UnaryDataset(Dataset):
+  """Abstract class representing a dataset with one input."""
+
+  def __init__(self, input_dataset):
+    super(UnaryDataset, self).__init__()
+    self._input_dataset = input_dataset
+
+  def _inputs(self):
+    return [self._input_dataset]
+
+
+class TensorDataset(DatasetSource):
   """A `Dataset` with a single element, viz. a nested structure of tensors."""
 
   def __init__(self, tensors):
@@ -1239,7 +1264,7 @@ class TensorDataset(Dataset):
     return self._output_types
 
 
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
  """A `Dataset` of slices from a nested structure of tensors."""
 
   def __init__(self, tensors):
@@ -1283,7 +1308,7 @@ class TensorSliceDataset(Dataset):
     return self._output_types
 
 
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
   """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
 
   def __init__(self, sparse_tensor):
@@ -1384,6 +1409,9 @@ class _VariantDataset(Dataset):
   def _as_variant_tensor(self):
     return self._dataset_variant
 
+  def _inputs(self):
+    return []
+
   @property
   def output_classes(self):
     return self._structure.output_classes
@@ -1624,7 +1652,7 @@ def flat_structure(dataset):
   }
 
 
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
   """A `Dataset` that generates elements by invoking a function."""
 
   def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1725,6 +1753,9 @@ class ZipDataset(Dataset):
         **flat_structure(self))
     # pylint: enable=protected-access
 
+  def _inputs(self):
+    return nest.flatten(self._datasets)
+
   @property
   def output_classes(self):
     return nest.pack_sequence_as(
@@ -1760,6 +1791,7 @@ class ConcatenateDataset(Dataset):
       raise TypeError(
           "Two datasets to concatenate have different classes %s and %s" %
          (input_dataset.output_classes, dataset_to_concatenate.output_classes))
+    self._input_datasets = [input_dataset, dataset_to_concatenate]
 
   def _as_variant_tensor(self):
     # pylint: disable=protected-access
@@ -1769,6 +1801,9 @@ class ConcatenateDataset(Dataset):
         **flat_structure(self))
     # pylint: enable=protected-access
 
+  def _inputs(self):
+    return [self._input_dataset, self._dataset_to_concatenate]
+
   @property
   def output_classes(self):
     return self._input_dataset.output_classes
@@ -1787,12 +1822,12 @@ class ConcatenateDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
   """A `Dataset` that repeats its input several times."""
 
   def __init__(self, input_dataset, count):
     """See `Dataset.repeat()` for details."""
-    super(RepeatDataset, self).__init__()
+    super(RepeatDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if count is None:
       self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1819,7 +1854,7 @@ class RepeatDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
   """A `Dataset` of a step separated range of values."""
 
   def __init__(self, *args):
@@ -1867,12 +1902,12 @@ class RangeDataset(Dataset):
     return dtypes.int64
 
 
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
   """A `Dataset` that caches elements of its input."""
 
   def __init__(self, input_dataset, filename):
     """See `Dataset.cache()` for details."""
-    super(CacheDataset, self).__init__()
+    super(CacheDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._filename = ops.convert_to_tensor(
         filename, dtype=dtypes.string, name="filename")
@@ -1896,7 +1931,7 @@ class CacheDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
   """A `Dataset` that randomly shuffles the elements of its input."""
 
   def __init__(self,
@@ -1924,7 +1959,7 @@ class ShuffleDataset(Dataset):
     Raises:
       ValueError: if invalid arguments are provided.
     """
-    super(ShuffleDataset, self).__init__()
+    super(ShuffleDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._buffer_size = ops.convert_to_tensor(
         buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1956,12 +1991,12 @@ class ShuffleDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
   """A `Dataset` containing the first `count` elements from its input."""
 
   def __init__(self, input_dataset, count):
     """See `Dataset.take()` for details."""
-    super(TakeDataset, self).__init__()
+    super(TakeDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
 
@@ -1984,12 +2019,12 @@ class TakeDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
   """A `Dataset` skipping the first `count` elements from its input."""
 
   def __init__(self, input_dataset, count):
     """See `Dataset.skip()` for details."""
-    super(SkipDataset, self).__init__()
+    super(SkipDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
 
@@ -2012,12 +2047,12 @@ class SkipDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
   """A `Dataset` that batches contiguous elements from its input."""
 
   def __init__(self, input_dataset, batch_size, drop_remainder):
     """See `Dataset.batch()` for details."""
-    super(BatchDataset, self).__init__()
+    super(BatchDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._batch_size = ops.convert_to_tensor(
         batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2166,13 +2201,13 @@ def _default_padding(input_dataset):
   return nest.map_structure(make_zero, input_dataset.output_types)
 
 
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
   """A `Dataset` that batches and pads contiguous elements from its input."""
 
   def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
                drop_remainder):
     """See `Dataset.batch()` for details."""
-    super(PaddedBatchDataset, self).__init__()
+    super(PaddedBatchDataset, self).__init__(input_dataset)
     if sparse.any_sparse(input_dataset.output_classes):
       # TODO(b/63669786): support batching of sparse tensors
       raise TypeError(
@@ -2272,12 +2307,12 @@ def _warn_if_collections(transformation_name):
       % transformation_name)
 
 
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
   """A `Dataset` that maps a function over elements in its input."""
 
   def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
     """See `Dataset.map()` for details."""
-    super(MapDataset, self).__init__()
+    super(MapDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._use_inter_op_parallelism = use_inter_op_parallelism
 
@@ -2338,12 +2373,12 @@ class ParallelMapDataset(MapDataset):
   # pylint: enable=protected-access
 
 
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
   """A `Dataset` that maps a function over its input and flattens the result."""
 
   def __init__(self, input_dataset, map_func):
     """See `Dataset.flat_map()` for details."""
-    super(FlatMapDataset, self).__init__()
+    super(FlatMapDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
 
     wrapped_func = StructuredFunctionWrapper(
@@ -2434,12 +2469,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
     return "Dataset.interleave()"
 
 
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
   """A `Dataset` that filters its input according to a predicate function."""
 
   def __init__(self, input_dataset, predicate):
     """See `Dataset.filter()` for details."""
-    super(FilterDataset, self).__init__()
+    super(FilterDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     wrapped_func = StructuredFunctionWrapper(
         predicate, "Dataset.filter()", input_dataset)
@@ -2469,12 +2504,12 @@ class FilterDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
   """A `Dataset` that asynchronously prefetches its input."""
 
   def __init__(self, input_dataset, buffer_size):
     """See `Dataset.prefetch()` for details."""
-    super(PrefetchDataset, self).__init__()
+    super(PrefetchDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if buffer_size is None:
       buffer_size = -1  # This is the sentinel for auto-tuning.
@@ -2500,12 +2535,12 @@ class PrefetchDataset(Dataset):
     return self._input_dataset.output_types
 
 
-class WindowDataset(Dataset):
+class WindowDataset(UnaryDataset):
   """A dataset that creates window datasets from the input elements."""
 
   def __init__(self, input_dataset, size, shift, stride, drop_remainder):
     """See `window_dataset()` for more details."""
-    super(WindowDataset, self).__init__()
+    super(WindowDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
     self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
index c914a43956..b7d3aac206 100644
--- a/tensorflow/python/data/ops/multi_device_iterator_ops.py
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -116,6 +116,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
         output_types=self._flat_output_types,
         output_shapes=self._flat_output_shapes)
 
+  def _inputs(self):
+    # TODO(b/116506223): Determine which datasets should be used as inputs here.
+    return []
+
   @property
   def output_types(self):
     return self._output_types
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 066e09969c..b0f26631f9 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -61,6 +61,9 @@ class TextLineDataset(dataset_ops.Dataset):
     return gen_dataset_ops.text_line_dataset(
         self._filenames, self._compression_type, self._buffer_size)
 
+  def _inputs(self):
+    return []
+
   @property
   def output_classes(self):
     return ops.Tensor
@@ -105,6 +108,9 @@ class _TFRecordDataset(dataset_ops.Dataset):
     return gen_dataset_ops.tf_record_dataset(
         self._filenames, self._compression_type, self._buffer_size)
 
+  def _inputs(self):
+    return []
+
   @property
   def output_classes(self):
     return ops.Tensor
@@ -224,6 +230,9 @@ class TFRecordDataset(dataset_ops.Dataset):
   def _as_variant_tensor(self):
     return self._impl._as_variant_tensor()  # pylint: disable=protected-access
 
+  def _inputs(self):
+    return self._impl._inputs()  # pylint: disable=protected-access
+
   @property
   def output_classes(self):
     return self._impl.output_classes
@@ -278,6 +287,9 @@ class FixedLengthRecordDataset(dataset_ops.Dataset):
         self._filenames, self._header_bytes, self._record_bytes,
         self._footer_bytes, self._buffer_size)
 
+  def _inputs(self):
+    return []
+
   @property
   def output_classes(self):
     return ops.Tensor
-- 
cgit v1.2.3
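
For illustration only, and not part of the patch above: a minimal sketch of how the private `_inputs()` method introduced here can be used to walk a dataset's input graph, mirroring the breadth-first traversal in `testCollectInputs`. The helper name `collect_transitive_inputs` is hypothetical, and `_inputs()` remains internal, unstable API.

```python
# Minimal sketch, assuming the patch above is applied. `_inputs()` is a
# private method, so code outside TensorFlow should not depend on it; the
# helper below only illustrates the recursive tracking the change enables.
from tensorflow.python.data.ops import dataset_ops


def collect_transitive_inputs(dataset):
  """Returns every dataset reachable from `dataset`, duplicates included."""
  collected = []
  queue = [dataset]
  while queue:
    ds = queue.pop(0)
    collected.append(ds)
    # Source datasets (e.g. RangeDataset) report no inputs, unary
    # transformations (e.g. MapDataset) report their single input, and
    # ZipDataset / ConcatenateDataset report all of their inputs.
    queue.extend(ds._inputs())  # pylint: disable=protected-access
  return collected


ds1 = dataset_ops.Dataset.range(0)
ds2 = ds1.concatenate(ds1)
ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

inputs = collect_transitive_inputs(ds3)
assert inputs.count(ds1) == 5  # once via ds3, twice via each of the two ds2.
assert inputs.count(ds2) == 2
assert inputs.count(ds3) == 1
```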