author    | Jiri Simsa <jsimsa@google.com>                  | 2018-06-21 10:02:46 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 10:05:44 -0700
commit    | 293b21eddc34ee0ceda1143ec7699e54c9768a1c (patch)
tree      | 30dca0da1a878616369ad96fc0275c06bf8770c7
parent    | 6eb9820f131448fcbb8a8cfc195a112dcb503fcc (diff)
[tf.data] Cleanup of tf.data.contrib, properly exporting public API.
PiperOrigin-RevId: 201542140
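
For reference, the user-visible effect of this change is that `RandomDataset`, `Reducer`, `TFRecordWriter`, and `unique` become part of the documented `tf.contrib.data` surface. The following is a minimal sketch of how the newly exported names could be exercised, assuming a TF 1.x build that includes this commit; the output path and sample values are illustrative, not taken from the change:

```python
# Sketch only: exercises the symbols newly exported by this change.
# Assumes TensorFlow 1.x with this commit applied; "/tmp/random.tfrecord"
# and the sample values are made up for illustration.
import tensorflow as tf

# `unique()` drops duplicate elements from a dataset.
ds = tf.data.Dataset.from_tensor_slices([1, 1, 2, 3, 3, 3])
ds = ds.apply(tf.contrib.data.unique())  # yields 1, 2, 3

# `RandomDataset` is an endless stream of pseudo-random int64 scalars.
random_ds = tf.contrib.data.RandomDataset(seed=42).take(5)

# `Reducer` packages (init, reduce, finalize) functions for use with
# `group_by_reducer`; here it sums the elements within each key group.
reducer = tf.contrib.data.Reducer(
    init_func=lambda _: 0,
    reduce_func=lambda state, value: state + value,
    finalize_func=lambda state: state)
sums = ds.apply(tf.contrib.data.group_by_reducer(
    key_func=lambda x: tf.cast(x % 2, tf.int64),  # key must be tf.int64
    reducer=reducer))

# `TFRecordWriter` serializes a dataset of scalar strings to disk.
writer = tf.contrib.data.TFRecordWriter("/tmp/random.tfrecord")
write_op = writer.write(random_ds.map(tf.as_string))

with tf.Session() as sess:
  sess.run(write_op)
```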
-rw-r--r-- | tensorflow/contrib/data/__init__.py                                             | 13
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD                               | 20
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py |  4
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/ops/iterator_ops_test.py) | 0
-rw-r--r-- | tensorflow/contrib/data/python/ops/BUILD                                        | 20
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py                                  | 14
-rw-r--r-- | tensorflow/contrib/data/python/ops/error_ops.py                                 |  6
-rw-r--r-- | tensorflow/contrib/data/python/ops/grouping.py                                  | 14
-rw-r--r-- | tensorflow/contrib/data/python/ops/interleave_ops.py                            |  6
-rw-r--r-- | tensorflow/contrib/data/python/ops/optimization.py                              |  6
-rw-r--r-- | tensorflow/contrib/data/python/ops/stats_ops.py                                 | 11
-rw-r--r-- | tensorflow/contrib/data/python/ops/threadpool.py                                |  4
-rw-r--r-- | tensorflow/contrib/data/python/ops/unique.py                                    |  6
13 files changed, 73 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 99699cd6d6..2a4cf877f0 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,7 +25,10 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@Counter
 @@CheckpointInputPipelineHook
 @@CsvDataset
+@@RandomDataset
+@@Reducer
 @@SqlDataset
+@@TFRecordWriter
 
 @@assert_element_shape
 @@batch_and_drop_remainder
@@ -33,12 +36,15 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@choose_from_datasets
 @@dense_to_sparse_batch
 @@enumerate_dataset
+
+@@get_single_element
 @@group_by_reducer
 @@group_by_window
 @@ignore_errors
 @@make_batched_features_dataset
 @@make_csv_dataset
 @@make_saveable_from_iterator
+
 @@map_and_batch
 @@padded_batch_and_drop_remainder
 @@parallel_interleave
@@ -51,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@sliding_window_batch
 @@sloppy_interleave
 @@unbatch
-
-@@get_single_element
+@@unique
 """
 
 from __future__ import absolute_import
@@ -74,6 +79,7 @@ from tensorflow.contrib.data.python.ops.get_single_element import get_single_ele
 from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
 from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
 from tensorflow.contrib.data.python.ops.grouping import group_by_window
+from tensorflow.contrib.data.python.ops.grouping import Reducer
 from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets
 from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
 from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
@@ -81,6 +87,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
 from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
 from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
+from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
 from tensorflow.contrib.data.python.ops.readers import CsvDataset
 from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
 from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
@@ -90,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
 from tensorflow.contrib.data.python.ops.scan_ops import scan
 from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
 from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.unique import unique
+from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
 # pylint: enable=unused-import
 
 from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ed1542d03f..ef9f966fab 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -158,6 +158,26 @@ py_test(
 )
 
 py_test(
+    name = "iterator_ops_test",
+    size = "small",
+    srcs = ["iterator_ops_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        "//tensorflow/contrib/data/python/ops:iterator_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:model_fn",
+    ],
+)
+
+py_test(
     name = "map_dataset_op_test",
     size = "medium",
     srcs = ["map_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index fe618cdce6..9b1857de1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -33,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase):
     input_datasets = [
         dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
     ]
-    dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
-                                                       input_datasets)
+    dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+                                                        input_datasets)
     iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..30a993b1f7 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 33b7a75046..0240814562 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -49,26 +49,6 @@ py_library(
     ],
 )
 
-py_test(
-    name = "iterator_ops_test",
-    size = "small",
-    srcs = ["iterator_ops_test.py"],
-    srcs_version = "PY2AND3",
-    tags = ["no_pip"],
-    deps = [
-        ":iterator_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variables",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/estimator",
-        "//tensorflow/python/estimator:model_fn",
-    ],
-)
-
 py_library(
     name = "random_ops",
     srcs = [
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 052618e08c..5708d47c20 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -77,17 +77,17 @@ def dense_to_sparse_batch(batch_size, row_shape):
   """
 
   def _apply_fn(dataset):
-    return DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
 
   return _apply_fn
 
 
-class UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.Dataset):
   """A dataset that splits the elements of its input into multiple elements."""
 
   def __init__(self, input_dataset):
     """See `unbatch()` for more details."""
-    super(UnbatchDataset, self).__init__()
+    super(_UnbatchDataset, self).__init__()
     flat_shapes = nest.flatten(input_dataset.output_shapes)
     if any(s.ndims == 0 for s in flat_shapes):
       raise ValueError("Cannot unbatch an input with scalar components.")
@@ -144,7 +144,7 @@ def unbatch():
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
     if not sparse.any_sparse(dataset.output_classes):
-      return UnbatchDataset(dataset)
+      return _UnbatchDataset(dataset)
     # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
     # are normalized to the rank-1 dense representation, so that the
@@ -170,7 +170,7 @@ def unbatch():
         dataset.output_shapes,
         dataset.output_classes,
         allow_unsafe_cast=True)
-    return UnbatchDataset(restructured_dataset)
+    return _UnbatchDataset(restructured_dataset)
 
   return _apply_fn
 
@@ -298,12 +298,12 @@ def padded_batch_and_drop_remainder(batch_size,
   return _apply_fn
 
 
-class DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.Dataset):
   """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
 
   def __init__(self, input_dataset, batch_size, row_shape):
     """See `Dataset.dense_to_sparse_batch()` for more details."""
-    super(DenseToSparseBatchDataset, self).__init__()
+    super(_DenseToSparseBatchDataset, self).__init__()
     if not isinstance(input_dataset.output_types, dtypes.DType):
       raise TypeError("DenseToSparseDataset requires an input whose elements "
                       "have a single component, whereas the input has %r." %
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 5f5513849c..d46d96c461 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -46,17 +46,17 @@ def ignore_errors():
   """
 
   def _apply_fn(dataset):
-    return IgnoreErrorsDataset(dataset)
+    return _IgnoreErrorsDataset(dataset)
 
   return _apply_fn
 
 
-class IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.Dataset):
   """A `Dataset` that silently ignores errors when computing its input."""
 
   def __init__(self, input_dataset):
     """See `Dataset.ignore_errors()` for details."""
-    super(IgnoreErrorsDataset, self).__init__()
+    super(_IgnoreErrorsDataset, self).__init__()
     self._input_dataset = input_dataset
 
   def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 4068a2ffa5..348884e9fa 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -55,7 +55,7 @@ def group_by_reducer(key_func, reducer):
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    return GroupByReducerDataset(dataset, key_func, reducer)
+    return _GroupByReducerDataset(dataset, key_func, reducer)
 
   return _apply_fn
 
@@ -113,8 +113,8 @@ def group_by_window(key_func,
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    return GroupByWindowDataset(dataset, key_func, reduce_func,
-                                window_size_func)
+    return _GroupByWindowDataset(dataset, key_func, reduce_func,
+                                 window_size_func)
 
   return _apply_fn
 
@@ -254,12 +254,12 @@ class _VariantDataset(dataset_ops.Dataset):
     return self._output_types
 
 
-class GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.Dataset):
  """A `Dataset` that groups its input and performs a reduction."""
 
   def __init__(self, input_dataset, key_func, reducer):
     """See `group_by_reducer()` for details."""
-    super(GroupByReducerDataset, self).__init__()
+    super(_GroupByReducerDataset, self).__init__()
     self._input_dataset = input_dataset
 
@@ -388,12 +388,12 @@ class GroupByReducerDataset(dataset_ops.Dataset):
     **dataset_ops.flat_structure(self))
 
 
-class GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.Dataset):
   """A `Dataset` that groups its input and performs a windowed reduction."""
 
   def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
     """See `group_by_window()` for details."""
-    super(GroupByWindowDataset, self).__init__()
+    super(_GroupByWindowDataset, self).__init__()
     self._input_dataset = input_dataset
 
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 70153ac575..bcc959594a 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
   return _apply_fn
 
 
-class DirectedInterleaveDataset(dataset_ops.Dataset):
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
   """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
 
   def __init__(self, selector_input, data_inputs):
@@ -236,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None):
   selector_input = dataset_ops.Dataset.zip(
       (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
 
-  return DirectedInterleaveDataset(selector_input, datasets)
+  return _DirectedInterleaveDataset(selector_input, datasets)
 
 
 def choose_from_datasets(datasets, choice_dataset):
@@ -280,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset):
       and choice_dataset.output_classes == ops.Tensor):
     raise TypeError("`choice_dataset` must be a dataset of scalar "
                     "`tf.int64` tensors.")
-  return DirectedInterleaveDataset(choice_dataset, datasets)
+  return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 2ca3805d66..cf89657226 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -39,17 +39,17 @@ def optimize(optimizations=None):
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    return OptimizeDataset(dataset, optimizations)
+    return _OptimizeDataset(dataset, optimizations)
 
   return _apply_fn
 
 
-class OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.Dataset):
   """A `Dataset` that acts as an identity, and applies optimizations."""
 
   def __init__(self, input_dataset, optimizations):
     """See `optimize()` for details."""
-    super(OptimizeDataset, self).__init__()
+    super(_OptimizeDataset, self).__init__()
     self._input_dataset = input_dataset
     if optimizations is None:
       optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3c82a03df1..97931f75bd 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 
 
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 class StatsAggregator(object):
   """A stateful resource that aggregates statistics from one or more iterators.
 
@@ -110,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
     return self._input_dataset.output_classes
 
 
-# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 def set_stats_aggregator(stats_aggregator):
   """Set the given stats_aggregator for aggregating the input dataset stats.
 
@@ -128,6 +131,8 @@ def set_stats_aggregator(stats_aggregator):
   return _apply_fn
 
 
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 def bytes_produced_stats(tag):
   """Records the number of bytes produced by each element of the input dataset.
 
@@ -150,6 +155,8 @@ def bytes_produced_stats(tag):
   return _apply_fn
 
 
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 def latency_stats(tag):
   """Records the latency of producing each element of the input dataset.
 
@@ -171,6 +178,8 @@ def latency_stats(tag):
   return _apply_fn
 
 
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 def feature_stats(tag):
   """Records the features stats from `Example` records of the input dataset.
 
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index bb49604d4d..f228660176 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -37,6 +37,8 @@ def _generate_shared_name(prefix):
   return "{}{}".format(prefix, uid)
 
 
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 class PrivateThreadPool(object):
   """A stateful resource that represents a private thread pool."""
 
@@ -82,6 +84,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset):
     return self._input_dataset.output_classes
 
 
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
 def override_threadpool(dataset, thread_pool):
   """Returns a new dataset that uses the given thread pool for its operations.
 
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index 4ce6ddede8..e0ce0a4ef1 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -42,17 +42,17 @@ def unique():
   """
 
   def _apply_fn(dataset):
-    return UniqueDataset(dataset)
+    return _UniqueDataset(dataset)
 
   return _apply_fn
 
 
-class UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.Dataset):
   """A `Dataset` contains the unique elements from its input."""
 
   def __init__(self, input_dataset):
     """See `unique()` for details."""
-    super(UniqueDataset, self).__init__()
+    super(_UniqueDataset, self).__init__()
     self._input_dataset = input_dataset
     if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
                                           dtypes.string):
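
The class renames above (`UnbatchDataset` to `_UnbatchDataset`, and so on) all serve the same goal: `remove_undocumented` keeps only the names mentioned as `@@`-directives in the module docstring, while leading-underscore classes stay reachable solely through their documented wrapper functions. Below is a schematic sketch of that export pattern, assuming a TF 1.x environment; the module contents are invented for illustration and are not code from this commit:

```python
# Schematic illustration of the export pattern enforced by this change,
# not code from the commit. Assumes TF 1.x, where `remove_undocumented`
# from tensorflow.python.util.all_util is available.
"""Example module.

@@unique
"""

from tensorflow.python.util.all_util import remove_undocumented


class _UniqueDataset(object):
  """Leading underscore: an implementation detail, never exported."""

  def __init__(self, input_dataset):
    self._input_dataset = input_dataset


def unique():
  """Listed as @@unique in the docstring above, so it stays public."""

  def _apply_fn(dataset):
    return _UniqueDataset(dataset)

  return _apply_fn


# Hides every module attribute not named in an @@-directive (or otherwise
# allowed), leaving `unique` as the module's only public symbol.
remove_undocumented(__name__)
```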