author | 2018-09-25 13:42:46 -0700
---|---
committer | 2018-09-25 13:46:54 -0700
commit | 348478f642216cf3cbe1eb67b875252d8e6a6418 (patch)
tree | c4c7afd4283506b2c413e429cab02fc547e13481 /tensorflow/contrib/data
parent | 976fb3105312bb17accebcbca2ebae906bcf99fb (diff)
[tf.data] Adding a private method for (recursively) tracking dataset inputs.
PiperOrigin-RevId: 214495925
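The base classes this change migrates onto (`dataset_ops.UnaryDataset`, `dataset_ops.DatasetSource`) and the private tracking method itself live in core `dataset_ops.py`, which is outside this diff (the listing below is limited to `tensorflow/contrib/data`). A minimal sketch of the presumed contract, not the actual implementation — stub names other than `UnaryDataset`, `DatasetSource`, and `_inputs` are illustrative:

```python
# Sketch of the presumed base-class contract (assumed, not taken from this
# diff): sources report no inputs, unary transformations report exactly one.


class Dataset(object):
  """Stripped-down stand-in for tf.data's real Dataset base class."""

  def _inputs(self):
    """Returns the list of datasets this dataset consumes directly."""
    raise NotImplementedError("Dataset._inputs")


class DatasetSource(Dataset):
  """A dataset with no dataset inputs (e.g. a file or SQL reader)."""

  def _inputs(self):
    return []


class UnaryDataset(Dataset):
  """A transformation with exactly one input dataset."""

  def __init__(self, input_dataset):
    super(UnaryDataset, self).__init__()
    self._tracked_input = input_dataset  # attribute name is illustrative

  def _inputs(self):
    return [self._tracked_input]
```

Under such a contract, each subclass in the diff below only needs to forward its input to the superclass constructor; datasets with multiple inputs, such as `_DirectedInterleaveDataset`, override `_inputs()` directly instead.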
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 12
-rw-r--r-- | tensorflow/contrib/data/python/ops/error_ops.py | 4
-rw-r--r-- | tensorflow/contrib/data/python/ops/grouping.py | 12
-rw-r--r-- | tensorflow/contrib/data/python/ops/indexed_dataset_ops.py | 3
-rw-r--r-- | tensorflow/contrib/data/python/ops/interleave_ops.py | 3
-rw-r--r-- | tensorflow/contrib/data/python/ops/optimization.py | 12
-rw-r--r-- | tensorflow/contrib/data/python/ops/parsing_ops.py | 4
-rw-r--r-- | tensorflow/contrib/data/python/ops/prefetching_ops.py | 10
-rw-r--r-- | tensorflow/contrib/data/python/ops/random_ops.py | 2
-rw-r--r-- | tensorflow/contrib/data/python/ops/readers.py | 6
-rw-r--r-- | tensorflow/contrib/data/python/ops/scan_ops.py | 4
-rw-r--r-- | tensorflow/contrib/data/python/ops/shuffle_ops.py | 11
-rw-r--r-- | tensorflow/contrib/data/python/ops/sliding.py | 4
-rw-r--r-- | tensorflow/contrib/data/python/ops/stats_ops.py | 8
-rw-r--r-- | tensorflow/contrib/data/python/ops/threadpool.py | 4
-rw-r--r-- | tensorflow/contrib/data/python/ops/unique.py | 4
16 files changed, 55 insertions, 48 deletions
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 367c159dc5..7a0f221284 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -345,12 +345,12 @@ def _padded_batch_sparse_window(dataset, padded_shape):
       dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
 
 
-class _UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.UnaryDataset):
   """A dataset that splits the elements of its input into multiple elements."""
 
   def __init__(self, input_dataset):
     """See `unbatch()` for more details."""
-    super(_UnbatchDataset, self).__init__()
+    super(_UnbatchDataset, self).__init__(input_dataset)
     flat_shapes = nest.flatten(input_dataset.output_shapes)
     if any(s.ndims == 0 for s in flat_shapes):
       raise ValueError("Cannot unbatch an input with scalar components.")
@@ -514,12 +514,12 @@ def padded_batch_and_drop_remainder(batch_size,
   return _apply_fn
 
 
-class _DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
 
   def __init__(self, input_dataset, batch_size, row_shape):
     """See `Dataset.dense_to_sparse_batch()` for more details."""
-    super(_DenseToSparseBatchDataset, self).__init__()
+    super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
     if not isinstance(input_dataset.output_types, dtypes.DType):
       raise TypeError("DenseToSparseDataset requires an input whose elements "
                       "have a single component, whereas the input has %r." %
@@ -548,7 +548,7 @@ class _DenseToSparseBatchDataset(dataset_ops.Dataset):
     return self._input_dataset.output_types
 
 
-class _RestructuredDataset(dataset_ops.Dataset):
+class _RestructuredDataset(dataset_ops.UnaryDataset):
   """An internal helper for changing the structure and shape of a dataset."""
 
   def __init__(self,
@@ -583,7 +583,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
       ValueError: If either `output_types` or `output_shapes` is not compatible
         with the structure of `dataset`.
     """
-    super(_RestructuredDataset, self).__init__()
+    super(_RestructuredDataset, self).__init__(dataset)
     self._input_dataset = dataset
 
     if not allow_unsafe_cast:
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index b4a7521e08..615dbcabd4 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -51,12 +51,12 @@ def ignore_errors():
   return _apply_fn
 
 
-class _IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that silently ignores errors when computing its input."""
 
   def __init__(self, input_dataset):
     """See `Dataset.ignore_errors()` for details."""
-    super(_IgnoreErrorsDataset, self).__init__()
+    super(_IgnoreErrorsDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
 
   def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 020167e4d1..7cae33beb3 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -282,12 +282,12 @@ def window_dataset(window_size):
   return _apply_fn
 
 
-class _GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that groups its input and performs a reduction."""
 
   def __init__(self, input_dataset, key_func, reducer):
     """See `group_by_reducer()` for details."""
-    super(_GroupByReducerDataset, self).__init__()
+    super(_GroupByReducerDataset, self).__init__(input_dataset)
 
     self._input_dataset = input_dataset
 
@@ -416,12 +416,12 @@ class _GroupByReducerDataset(dataset_ops.Dataset):
         **dataset_ops.flat_structure(self))
 
 
-class _GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that groups its input and performs a windowed reduction."""
 
   def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
     """See `group_by_window()` for details."""
-    super(_GroupByWindowDataset, self).__init__()
+    super(_GroupByWindowDataset, self).__init__(input_dataset)
 
     self._input_dataset = input_dataset
 
@@ -525,12 +525,12 @@ class Reducer(object):
     return self._finalize_func
 
 
-class _MapXDataset(dataset_ops.Dataset):
+class _MapXDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that maps a function over elements in its input."""
 
   def __init__(self, input_dataset, map_func):
     """See `map_x_dataset()` for details."""
-    super(_MapXDataset, self).__init__()
+    super(_MapXDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
 
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index a0932b4081..cc76ab0850 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -171,3 +171,6 @@ class IdentityIndexedDataset(IndexedDataset):
 
   def _as_variant_tensor(self):
     return gen_dataset_ops.identity_indexed_dataset(self._size)
+
+  def _inputs(self):
+    return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 92d4251a86..bfa3fdf543 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -173,6 +173,9 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
         **dataset_ops.flat_structure(self))
     # pylint: enable=protected-access
 
+  def _inputs(self):
+    return [self._selector_input] + self._data_inputs
+
   @property
   def output_classes(self):
     return self._data_inputs[0].output_classes
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 73840452df..3eb172acd5 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -84,12 +84,12 @@ def optimize(optimizations=None):
   return _apply_fn
 
 
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that asserts which transformations happen next."""
 
   def __init__(self, input_dataset, transformations):
     """See `assert_next()` for details."""
-    super(_AssertNextDataset, self).__init__()
+    super(_AssertNextDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if transformations is None:
       raise ValueError("At least one transformation should be specified")
@@ -115,12 +115,12 @@ class _AssertNextDataset(dataset_ops.Dataset):
     return self._input_dataset.output_types
 
 
-class _ModelDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that acts as an identity, and models performance."""
 
   def __init__(self, input_dataset):
     """See `optimize()` for details."""
-    super(_ModelDataset, self).__init__()
+    super(_ModelDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
 
   def _as_variant_tensor(self):
@@ -141,12 +141,12 @@ class _ModelDataset(dataset_ops.Dataset):
     return self._input_dataset.output_types
 
 
-class _OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that acts as an identity, and applies optimizations."""
 
   def __init__(self, input_dataset, optimizations):
     """See `optimize()` for details."""
-    super(_OptimizeDataset, self).__init__()
+    super(_OptimizeDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if optimizations is None:
       optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index 2701605e64..cfbba701b0 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -26,11 +26,11 @@ from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import parsing_ops
 
 
-class _ParseExampleDataset(dataset_ops.Dataset):
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that parses `example` dataset into a `dict` dataset."""
 
   def __init__(self, input_dataset, features, num_parallel_calls):
-    super(_ParseExampleDataset, self).__init__()
+    super(_ParseExampleDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if not all(types == dtypes.string
                for types in nest.flatten(input_dataset.output_types)):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 5222011d04..f994425304 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -262,10 +262,11 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
   # pylint: enable=protected-access
 
 
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
   """A `Dataset` whose iterator prefetches elements to another device."""
 
   def __init__(self, input_dataset, device, buffer_size):
+    super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._device = device
     self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -374,7 +375,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
 # TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
 # all inputs to the Op are in host memory, thereby avoiding some unnecessary
 # Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.Dataset):
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that copies elements to another device."""
 
   def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
@@ -385,6 +386,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
       target_device: The name of the device to which elements would be copied.
       source_device: Device where input_dataset would be placed.
     """
+    super(_CopyToDeviceDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._target_device = target_device
     spec = framework_device.DeviceSpec().from_string(self._target_device)
@@ -612,6 +614,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
         output_types=self._flat_output_types,
         output_shapes=self._flat_output_shapes)
 
+  def _inputs(self):
+    # TODO(b/116506223): Determine which datasets should be used as inputs here.
+    return []
+
   @property
   def output_types(self):
     return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index e670c4c835..344a0763c8 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_dataset_ops
 
 
-class RandomDataset(dataset_ops.Dataset):
+class RandomDataset(dataset_ops.DatasetSource):
   """A `Dataset` of pseudorandom values."""
 
   def __init__(self, seed=None):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 785b395707..d9d06e2703 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -508,7 +508,7 @@ def make_csv_dataset(
 _DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024  # 4 MB
 
 
-class CsvDataset(dataset_ops.Dataset):
+class CsvDataset(dataset_ops.DatasetSource):
   """A Dataset comprising lines from one or more CSV files."""
 
   def __init__(self,
@@ -924,7 +924,7 @@ def _get_file_names(file_pattern, shuffle):
   return file_names
 
 
-class SqlDataset(dataset_ops.Dataset):
+class SqlDataset(dataset_ops.DatasetSource):
   """A `Dataset` consisting of the results from a SQL query."""
 
   def __init__(self, driver_name, data_source_name, query, output_types):
@@ -985,7 +985,7 @@ class SqlDataset(dataset_ops.Dataset):
     return self._output_types
 
 
-class LMDBDataset(dataset_ops.Dataset):
+class LMDBDataset(dataset_ops.DatasetSource):
   """A LMDB Dataset that reads the lmdb file."""
 
   def __init__(self, filenames):
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 6b002b4a53..c52582cd35 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -27,12 +27,12 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import gen_dataset_ops
 
 
-class _ScanDataset(dataset_ops.Dataset):
+class _ScanDataset(dataset_ops.UnaryDataset):
   """A dataset that scans a function across its input."""
 
   def __init__(self, input_dataset, initial_state, scan_func):
     """See `scan()` for details."""
-    super(_ScanDataset, self).__init__()
+    super(_ScanDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
 
     with ops.name_scope("initial_state"):
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 4356721704..985d1d87d0 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -25,16 +25,11 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 
 
-class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that fuses `shuffle` and `repeat`."""
 
-  def __init__(self,
-               input_dataset,
-               buffer_size,
-               count=None,
-               seed=None):
-    """See `Dataset.map()` for details."""
-    super(_ShuffleAndRepeatDataset, self).__init__()
+  def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+    super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._buffer_size = ops.convert_to_tensor(
         buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index b0d6a16c20..bcc383587c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.util import deprecation
 
 
-class _SlideDataset(dataset_ops.Dataset):
+class _SlideDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that passes a sliding window over its input."""
 
   def __init__(self, input_dataset, window_size, window_shift, window_stride):
     """See `sliding_window_batch` for details."""
-    super(_SlideDataset, self).__init__()
+    super(_SlideDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._window_size = ops.convert_to_tensor(
         window_size, dtype=dtypes.int64, name="window_stride")
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 7410ee8e05..bc47c5989d 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -84,11 +84,11 @@ class StatsAggregator(object):
     return gen_dataset_ops.stats_aggregator_summary(self._resource)
 
 
-class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
 
   def __init__(self, input_dataset, stats_aggregator):
-    super(_SetStatsAggregatorDataset, self).__init__()
+    super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._stats_aggregator = stats_aggregator
 
@@ -173,11 +173,11 @@ def latency_stats(tag):
   return _apply_fn
 
 
-class _StatsDataset(dataset_ops.Dataset):
+class _StatsDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that acts as an identity, and also records statistics."""
 
   def __init__(self, input_dataset, op_function, tag):
-    super(_StatsDataset, self).__init__()
+    super(_StatsDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._op_function = op_function
     self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index dc67accdcf..9d165ad52a 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -61,11 +61,11 @@ class PrivateThreadPool(object):
         display_name=display_name)
 
 
-class _ThreadPoolDataset(dataset_ops.Dataset):
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that acts as an identity, and sets a custom threadpool."""
 
   def __init__(self, input_dataset, thread_pool):
-    super(_ThreadPoolDataset, self).__init__()
+    super(_ThreadPoolDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     self._thread_pool = thread_pool
 
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index e0d606311c..bad67a580d 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -47,12 +47,12 @@ def unique():
   return _apply_fn
 
 
-class _UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.UnaryDataset):
   """A `Dataset` contains the unique elements from its input."""
 
   def __init__(self, input_dataset):
     """See `unique()` for details."""
-    super(_UniqueDataset, self).__init__()
+    super(_UniqueDataset, self).__init__(input_dataset)
     self._input_dataset = input_dataset
     if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
                                           dtypes.string):
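Once every dataset reports its direct inputs via `_inputs()`, recursive tracking reduces to a walk over the resulting graph. A hedged sketch of such a traversal — the helper name `_walk_inputs` is hypothetical, since the private method named in the commit message is defined outside this contrib-only diff:

```python
def _walk_inputs(dataset, visited=None):
  """Hypothetical traversal: yields `dataset` and all transitive inputs."""
  if visited is None:
    visited = set()
  if id(dataset) in visited:  # guard against shared upstream datasets
    return
  visited.add(id(dataset))
  yield dataset
  for input_dataset in dataset._inputs():  # pylint: disable=protected-access
    for upstream in _walk_inputs(input_dataset, visited):
      yield upstream
```

For example, `list(_walk_inputs(ds))` would surface `ds` and every dataset upstream of it, which is the kind of information serialization and static-optimization passes need; deduplicating by `id()` keeps datasets that feed multiple consumers from being visited twice.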