author Jiri Simsa <jsimsa@google.com> 2018-09-25 13:42:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-25 13:46:54 -0700
commit 348478f642216cf3cbe1eb67b875252d8e6a6418 (patch)
tree c4c7afd4283506b2c413e429cab02fc547e13481 /tensorflow/python/data
parent 976fb3105312bb17accebcbca2ebae906bcf99fb (diff)
[tf.data] Adding a private method for (recursively) tracking dataset inputs.
PiperOrigin-RevId: 214495925
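
For context: the new private `_inputs()` method returns a dataset's immediate input datasets, so a caller can recover the full input graph with a breadth-first walk. A minimal sketch of such a traversal, mirroring the `testCollectInputs` case added below (the `collect_inputs` helper is illustrative and not part of this change):

    from tensorflow.python.data.ops import dataset_ops

    def collect_inputs(dataset):
      # Breadth-first walk over the private _inputs() edges; a shared input
      # is visited once per reference, not once overall.
      result = []
      queue = [dataset]
      while queue:
        ds = queue.pop(0)
        result.append(ds)
        queue.extend(ds._inputs())  # pylint: disable=protected-access
      return result

    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    print(len(collect_inputs(ds2)))  # 3: ds2 plus two references to ds1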
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD                  13
-rw-r--r--  tensorflow/python/data/kernel_tests/inputs_test.py        148
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py                 109
-rw-r--r--  tensorflow/python/data/ops/multi_device_iterator_ops.py     4
-rw-r--r--  tensorflow/python/data/ops/readers.py                      12
5 files changed, 249 insertions(+), 37 deletions(-)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 28ee3ebaa6..7a6f03d4d3 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -445,3 +445,16 @@ tf_py_test(
"//tensorflow/python/data/ops:dataset_ops",
],
)
+
+tf_py_test(
+ name = "inputs_test",
+ size = "small",
+ srcs = ["inputs_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
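
The new target can then be run like any other kernel test (assuming a configured TensorFlow checkout):

    bazel test //tensorflow/python/data/kernel_tests:inputs_test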
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
new file mode 100644
index 0000000000..4c9279dd95
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class InputsTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+
+if __name__ == "__main__":
+ test.main()
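
Note that the breadth-first walk in `testCollectInputs` above counts a dataset once per reference, which is why `ds1` is counted five times: once directly under the zip and twice via each occurrence of `ds2`. A deduplicating variant (illustrative only, not part of this change) would track visited datasets:

    def unique_inputs(dataset):
      # Identity-based dedup; Dataset objects use default object equality.
      seen = []
      queue = [dataset]
      while queue:
        ds = queue.pop(0)
        if ds not in seen:
          seen.append(ds)
          queue.extend(ds._inputs())  # pylint: disable=protected-access
      return seen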
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7c20c049f5..ac87a451b1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
"""
raise NotImplementedError("Dataset._as_variant_tensor")
+ @abc.abstractmethod
+ def _inputs(self):
+ """Returns a list of the input datasets of the dataset."""
+
+ raise NotImplementedError("Dataset._inputs")
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -1007,8 +1013,8 @@ class Dataset(object):
return ParallelMapDataset(self, map_func, num_parallel_calls)
def flat_map(self, map_func):
- """Maps `map_func` across this dataset and flattens the result.
-
+ """Maps `map_func` across this dataset and flattens the result.
+
Use `flat_map` if you want to make sure that the order of your dataset
stays the same. For example, to flatten a dataset of batches into a
dataset of their elements:
@@ -1017,15 +1023,15 @@ class Dataset(object):
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset. '[...]' represents a tensor.
a = {[1,2,3,4,5], [6,7,8,9], [10]}
-
- a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
{[1,2,3,4,5,6,7,8,9,10]}
```
-
- `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
- `flat_map` produces the same output as
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
`tf.data.Dataset.interleave(cycle_length=1)`
-
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1157,6 +1163,7 @@ class Dataset(object):
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
raise TypeError("`transformation_func` must return a Dataset.")
+ dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
def window(self, size, shift=None, stride=1, drop_remainder=False):
@@ -1199,7 +1206,25 @@ class Dataset(object):
return WindowDataset(self, size, shift, stride, drop_remainder)
-class TensorDataset(Dataset):
+class DatasetSource(Dataset):
+ """Abstract class representing a dataset with no inputs."""
+
+ def _inputs(self):
+ return []
+
+
+class UnaryDataset(Dataset):
+ """Abstract class representing a dataset with one input."""
+
+ def __init__(self, input_dataset):
+ super(UnaryDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _inputs(self):
+ return [self._input_dataset]
+
+
+class TensorDataset(DatasetSource):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
@@ -1239,7 +1264,7 @@ class TensorDataset(Dataset):
return self._output_types
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
"""A `Dataset` of slices from a nested structure of tensors."""
def __init__(self, tensors):
@@ -1283,7 +1308,7 @@ class TensorSliceDataset(Dataset):
return self._output_types
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
"""A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
def __init__(self, sparse_tensor):
@@ -1384,6 +1409,9 @@ class _VariantDataset(Dataset):
def _as_variant_tensor(self):
return self._dataset_variant
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return self._structure.output_classes
@@ -1624,7 +1652,7 @@ def flat_structure(dataset):
}
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1725,6 +1753,9 @@ class ZipDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return nest.flatten(self._datasets)
+
@property
def output_classes(self):
return nest.pack_sequence_as(
@@ -1760,6 +1791,7 @@ class ConcatenateDataset(Dataset):
raise TypeError(
"Two datasets to concatenate have different classes %s and %s" %
(input_dataset.output_classes, dataset_to_concatenate.output_classes))
+ self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -1769,6 +1801,9 @@ class ConcatenateDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._input_dataset, self._dataset_to_concatenate]
+
@property
def output_classes(self):
return self._input_dataset.output_classes
@@ -1787,12 +1822,12 @@ class ConcatenateDataset(Dataset):
return self._input_dataset.output_types
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
"""A `Dataset` that repeats its input several times."""
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
- super(RepeatDataset, self).__init__()
+ super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1819,7 +1854,7 @@ class RepeatDataset(Dataset):
return self._input_dataset.output_types
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
@@ -1867,12 +1902,12 @@ class RangeDataset(Dataset):
return dtypes.int64
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
"""A `Dataset` that caches elements of its input."""
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
- super(CacheDataset, self).__init__()
+ super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
@@ -1896,7 +1931,7 @@ class CacheDataset(Dataset):
return self._input_dataset.output_types
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
def __init__(self,
@@ -1924,7 +1959,7 @@ class ShuffleDataset(Dataset):
Raises:
ValueError: if invalid arguments are provided.
"""
- super(ShuffleDataset, self).__init__()
+ super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1956,12 +1991,12 @@ class ShuffleDataset(Dataset):
return self._input_dataset.output_types
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
"""A `Dataset` containing the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
- super(TakeDataset, self).__init__()
+ super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1984,12 +2019,12 @@ class TakeDataset(Dataset):
return self._input_dataset.output_types
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
"""A `Dataset` skipping the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
- super(SkipDataset, self).__init__()
+ super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -2012,12 +2047,12 @@ class SkipDataset(Dataset):
return self._input_dataset.output_types
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
- super(BatchDataset, self).__init__()
+ super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2166,13 +2201,13 @@ def _default_padding(input_dataset):
return nest.map_structure(make_zero, input_dataset.output_types)
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
- super(PaddedBatchDataset, self).__init__()
+ super(PaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@@ -2272,12 +2307,12 @@ def _warn_if_collections(transformation_name):
% transformation_name)
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(MapDataset, self).__init__()
+ super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@@ -2338,12 +2373,12 @@ class ParallelMapDataset(MapDataset):
# pylint: enable=protected-access
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
- super(FlatMapDataset, self).__init__()
+ super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
@@ -2434,12 +2469,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
"""A `Dataset` that filters its input according to a predicate function."""
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
- super(FilterDataset, self).__init__()
+ super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, "Dataset.filter()", input_dataset)
@@ -2469,12 +2504,12 @@ class FilterDataset(Dataset):
return self._input_dataset.output_types
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
"""A `Dataset` that asynchronously prefetches its input."""
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
- super(PrefetchDataset, self).__init__()
+ super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
@@ -2500,12 +2535,12 @@ class PrefetchDataset(Dataset):
return self._input_dataset.output_types
-class WindowDataset(Dataset):
+class WindowDataset(UnaryDataset):
"""A dataset that creates window datasets from the input elements."""
def __init__(self, input_dataset, size, shift, stride, drop_remainder):
"""See `window_dataset()` for more details."""
- super(WindowDataset, self).__init__()
+ super(WindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
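
The refactoring in this file follows a single pattern: source datasets now subclass `DatasetSource`, whose `_inputs()` returns `[]`, and single-input transformations subclass `UnaryDataset`, which records its input and implements `_inputs()` once for all of them. A hypothetical pass-through transformation written against the same pattern (the class name `_IdentityDataset` is illustrative):

    class _IdentityDataset(UnaryDataset):
      """A hypothetical no-op dataset, shown only to illustrate the pattern."""

      def __init__(self, input_dataset):
        super(_IdentityDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset

      def _as_variant_tensor(self):
        # Defer to the input; _inputs() is inherited from UnaryDataset.
        return self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access

      @property
      def output_classes(self):
        return self._input_dataset.output_classes

      @property
      def output_shapes(self):
        return self._input_dataset.output_shapes

      @property
      def output_types(self):
        return self._input_dataset.output_types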
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
index c914a43956..b7d3aac206 100644
--- a/tensorflow/python/data/ops/multi_device_iterator_ops.py
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -116,6 +116,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
@property
def output_types(self):
return self._output_types
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 066e09969c..b0f26631f9 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -61,6 +61,9 @@ class TextLineDataset(dataset_ops.Dataset):
return gen_dataset_ops.text_line_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -105,6 +108,9 @@ class _TFRecordDataset(dataset_ops.Dataset):
return gen_dataset_ops.tf_record_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -224,6 +230,9 @@ class TFRecordDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
+ def _inputs(self):
+ return self._impl._inputs() # pylint: disable=protected-access
+
@property
def output_classes(self):
return self._impl.output_classes
@@ -278,6 +287,9 @@ class FixedLengthRecordDataset(dataset_ops.Dataset):
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor