path: root/tensorflow/contrib/data
author     Jiri Simsa <jsimsa@google.com>    2018-09-25 13:42:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-25 13:46:54 -0700
commit     348478f642216cf3cbe1eb67b875252d8e6a6418 (patch)
tree       c4c7afd4283506b2c413e429cab02fc547e13481 /tensorflow/contrib/data
parent     976fb3105312bb17accebcbca2ebae906bcf99fb (diff)
[tf.data] Adding a private method for (recursively) tracking dataset inputs.
PiperOrigin-RevId: 214495925
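
This change migrates the contrib transformations from subclassing `dataset_ops.Dataset` directly to subclassing `dataset_ops.UnaryDataset` (transformations with a single input dataset) or `dataset_ops.DatasetSource` (datasets with no dataset inputs), so every dataset can report its inputs through a private `_inputs()` method. The base-class definitions live in the core `dataset_ops` module and are not part of this diff; the minimal sketch below only illustrates the assumed contract, and the attribute name used to hold the wrapped input is hypothetical.

class Dataset(object):
  """Sketch of the assumed base class: subclasses report their dataset inputs."""

  def _inputs(self):
    """Returns the list of `Dataset` objects this dataset consumes."""
    raise NotImplementedError("Dataset._inputs")


class DatasetSource(Dataset):
  """A dataset that consumes no other datasets (e.g. file readers)."""

  def _inputs(self):
    return []


class UnaryDataset(Dataset):
  """A dataset that consumes exactly one input dataset."""

  def __init__(self, input_dataset):
    super(UnaryDataset, self).__init__()
    self._unary_input = input_dataset  # hypothetical attribute name

  def _inputs(self):
    return [self._unary_input]

Under this contract, a subclass only needs to pass its input dataset to the base-class constructor (as the hunks below do) and inherits a correct `_inputs()` implementation; datasets with more than one input override `_inputs()` explicitly.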
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py             12
-rw-r--r--  tensorflow/contrib/data/python/ops/error_ops.py             4
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py             12
-rw-r--r--  tensorflow/contrib/data/python/ops/indexed_dataset_ops.py   3
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py        3
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py         12
-rw-r--r--  tensorflow/contrib/data/python/ops/parsing_ops.py           4
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py      10
-rw-r--r--  tensorflow/contrib/data/python/ops/random_ops.py            2
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py               6
-rw-r--r--  tensorflow/contrib/data/python/ops/scan_ops.py              4
-rw-r--r--  tensorflow/contrib/data/python/ops/shuffle_ops.py          11
-rw-r--r--  tensorflow/contrib/data/python/ops/sliding.py               4
-rw-r--r--  tensorflow/contrib/data/python/ops/stats_ops.py             8
-rw-r--r--  tensorflow/contrib/data/python/ops/threadpool.py            4
-rw-r--r--  tensorflow/contrib/data/python/ops/unique.py                4
16 files changed, 55 insertions(+), 48 deletions(-)
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 367c159dc5..7a0f221284 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -345,12 +345,12 @@ def _padded_batch_sparse_window(dataset, padded_shape):
dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-class _UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.UnaryDataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__(input_dataset)
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -514,12 +514,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@@ -548,7 +548,7 @@ class _DenseToSparseBatchDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _RestructuredDataset(dataset_ops.Dataset):
+class _RestructuredDataset(dataset_ops.UnaryDataset):
"""An internal helper for changing the structure and shape of a dataset."""
def __init__(self,
@@ -583,7 +583,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
- super(_RestructuredDataset, self).__init__()
+ super(_RestructuredDataset, self).__init__(dataset)
self._input_dataset = dataset
if not allow_unsafe_cast:
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index b4a7521e08..615dbcabd4 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -51,12 +51,12 @@ def ignore_errors():
return _apply_fn
-class _IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 020167e4d1..7cae33beb3 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -282,12 +282,12 @@ def window_dataset(window_size):
return _apply_fn
-class _GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -416,12 +416,12 @@ class _GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class _GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -525,12 +525,12 @@ class Reducer(object):
return self._finalize_func
-class _MapXDataset(dataset_ops.Dataset):
+class _MapXDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func):
"""See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__()
+ super(_MapXDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = dataset_ops.StructuredFunctionWrapper(
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index a0932b4081..cc76ab0850 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -171,3 +171,6 @@ class IdentityIndexedDataset(IndexedDataset):
def _as_variant_tensor(self):
return gen_dataset_ops.identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 92d4251a86..bfa3fdf543 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -173,6 +173,9 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
@property
def output_classes(self):
return self._data_inputs[0].output_classes
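
`_DirectedInterleaveDataset` consumes a selector dataset plus several data datasets, so its `_inputs()` override returns all of them. Once every dataset reports its inputs this way, the input pipeline can be traversed recursively, which is the tracking the commit message refers to. A minimal sketch of such a traversal, not part of this change and with a hypothetical helper name:

def _walk_inputs(dataset, visited=None):
  """Yields `dataset` and every dataset it transitively consumes."""
  if visited is None:
    visited = set()
  if id(dataset) in visited:
    return
  visited.add(id(dataset))
  yield dataset
  for input_dataset in dataset._inputs():  # pylint: disable=protected-access
    for upstream in _walk_inputs(input_dataset, visited):
      yield upstream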
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 73840452df..3eb172acd5 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -84,12 +84,12 @@ def optimize(optimizations=None):
return _apply_fn
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that asserts which transformations happen next."""
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__()
+ super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
@@ -115,12 +115,12 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _ModelDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and models performance."""
def __init__(self, input_dataset):
"""See `optimize()` for details."""
- super(_ModelDataset, self).__init__()
+ super(_ModelDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
@@ -141,12 +141,12 @@ class _ModelDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index 2701605e64..cfbba701b0 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -26,11 +26,11 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
-class _ParseExampleDataset(dataset_ops.Dataset):
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__()
+ super(_ParseExampleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if not all(types == dtypes.string
for types in nest.flatten(input_dataset.output_types)):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 5222011d04..f994425304 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -262,10 +262,11 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to another device."""
def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -374,7 +375,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.Dataset):
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that copies elements to another device."""
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
@@ -385,6 +386,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
@@ -612,6 +614,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
@property
def output_types(self):
return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index e670c4c835..344a0763c8 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
-class RandomDataset(dataset_ops.Dataset):
+class RandomDataset(dataset_ops.DatasetSource):
"""A `Dataset` of pseudorandom values."""
def __init__(self, seed=None):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 785b395707..d9d06e2703 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -508,7 +508,7 @@ def make_csv_dataset(
_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-class CsvDataset(dataset_ops.Dataset):
+class CsvDataset(dataset_ops.DatasetSource):
"""A Dataset comprising lines from one or more CSV files."""
def __init__(self,
@@ -924,7 +924,7 @@ def _get_file_names(file_pattern, shuffle):
return file_names
-class SqlDataset(dataset_ops.Dataset):
+class SqlDataset(dataset_ops.DatasetSource):
"""A `Dataset` consisting of the results from a SQL query."""
def __init__(self, driver_name, data_source_name, query, output_types):
@@ -985,7 +985,7 @@ class SqlDataset(dataset_ops.Dataset):
return self._output_types
-class LMDBDataset(dataset_ops.Dataset):
+class LMDBDataset(dataset_ops.DatasetSource):
"""A LMDB Dataset that reads the lmdb file."""
def __init__(self, filenames):
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 6b002b4a53..c52582cd35 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -27,12 +27,12 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
-class _ScanDataset(dataset_ops.Dataset):
+class _ScanDataset(dataset_ops.UnaryDataset):
"""A dataset that scans a function across its input."""
def __init__(self, input_dataset, initial_state, scan_func):
"""See `scan()` for details."""
- super(_ScanDataset, self).__init__()
+ super(_ScanDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 4356721704..985d1d87d0 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -25,16 +25,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that fuses `shuffle` and `repeat`."""
- def __init__(self,
- input_dataset,
- buffer_size,
- count=None,
- seed=None):
- """See `Dataset.map()` for details."""
- super(_ShuffleAndRepeatDataset, self).__init__()
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index b0d6a16c20..bcc383587c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
-class _SlideDataset(dataset_ops.Dataset):
+class _SlideDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that passes a sliding window over its input."""
def __init__(self, input_dataset, window_size, window_shift, window_stride):
"""See `sliding_window_batch` for details."""
- super(_SlideDataset, self).__init__()
+ super(_SlideDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_stride")
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 7410ee8e05..bc47c5989d 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -84,11 +84,11 @@ class StatsAggregator(object):
return gen_dataset_ops.stats_aggregator_summary(self._resource)
-class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__()
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
@@ -173,11 +173,11 @@ def latency_stats(tag):
return _apply_fn
-class _StatsDataset(dataset_ops.Dataset):
+class _StatsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__()
+ super(_StatsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._op_function = op_function
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index dc67accdcf..9d165ad52a 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -61,11 +61,11 @@ class PrivateThreadPool(object):
display_name=display_name)
-class _ThreadPoolDataset(dataset_ops.Dataset):
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__()
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._thread_pool = thread_pool
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index e0d606311c..bad67a580d 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -47,12 +47,12 @@ def unique():
return _apply_fn
-class _UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.UnaryDataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(_UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):