author     Jiri Simsa <jsimsa@google.com>  2018-06-21 10:02:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-21 10:05:44 -0700
commit     293b21eddc34ee0ceda1143ec7699e54c9768a1c (patch)
tree       30dca0da1a878616369ad96fc0275c06bf8770c7
parent     6eb9820f131448fcbb8a8cfc195a112dcb503fcc (diff)
[tf.data] Cleanup of tf.contrib.data, properly exporting public API.
PiperOrigin-RevId: 201542140
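
This change adds RandomDataset, Reducer, TFRecordWriter, and unique to the documented tf.contrib.data surface (get_single_element is merely re-sorted). A minimal usage sketch of the newly exported names, assuming the TF 1.x API of this era; file paths and values outside the diff are illustrative only:

import tensorflow as tf

# unique(): drop duplicate elements from a dataset.
ds = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 1])
ds = ds.apply(tf.contrib.data.unique())  # yields 1, 37, 2

# RandomDataset: an infinite dataset of pseudo-random tf.int64 scalars.
rnd = tf.contrib.data.RandomDataset(seed=42).take(3)

# TFRecordWriter: serialize a dataset of tf.string scalars to a file.
strings = ds.map(tf.as_string)
write_op = tf.contrib.data.TFRecordWriter("/tmp/example.tfrecord").write(strings)

with tf.Session() as sess:
  sess.run(write_op)
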
-rw-r--r--  tensorflow/contrib/data/__init__.py | 13
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 20
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/ops/iterator_ops_test.py) | 0
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD | 20
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py | 14
-rw-r--r--  tensorflow/contrib/data/python/ops/error_ops.py | 6
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py | 14
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py | 6
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 6
-rw-r--r--  tensorflow/contrib/data/python/ops/stats_ops.py | 11
-rw-r--r--  tensorflow/contrib/data/python/ops/threadpool.py | 4
-rw-r--r--  tensorflow/contrib/data/python/ops/unique.py | 6
13 files changed, 73 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 99699cd6d6..2a4cf877f0 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,7 +25,10 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@RandomDataset
+@@Reducer
@@SqlDataset
+@@TFRecordWriter
@@assert_element_shape
@@batch_and_drop_remainder
@@ -33,12 +36,15 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@choose_from_datasets
@@dense_to_sparse_batch
@@enumerate_dataset
+
+@@get_single_element
@@group_by_reducer
@@group_by_window
@@ignore_errors
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
+
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
@@ -51,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@sliding_window_batch
@@sloppy_interleave
@@unbatch
-
-@@get_single_element
+@@unique
"""
from __future__ import absolute_import
@@ -74,6 +79,7 @@ from tensorflow.contrib.data.python.ops.get_single_element import get_single_ele
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
from tensorflow.contrib.data.python.ops.grouping import group_by_window
+from tensorflow.contrib.data.python.ops.grouping import Reducer
from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
@@ -81,6 +87,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
+from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
@@ -90,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.unique import unique
+from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
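
For context, the @@Name entries in the module docstring plus the trailing remove_undocumented call are what define this public surface: names imported into the module but not listed with @@ are stripped from it. A simplified, hypothetical module illustrating that pattern (not part of the commit):

"""Example API module.

@@unique
"""
from __future__ import absolute_import

# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.unique import unique
# pylint: enable=unused-import

from tensorflow.python.util.all_util import remove_undocumented

remove_undocumented(__name__)  # drops anything not mentioned as @@ above
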
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ed1542d03f..ef9f966fab 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -158,6 +158,26 @@ py_test(
)
py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index fe618cdce6..9b1857de1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -33,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase):
input_datasets = [
dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
]
- dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
- input_datasets)
+ dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..30a993b1f7 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 33b7a75046..0240814562 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -49,26 +49,6 @@ py_library(
],
)
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
-
py_library(
name = "random_ops",
srcs = [
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 052618e08c..5708d47c20 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -77,17 +77,17 @@ def dense_to_sparse_batch(batch_size, row_shape):
"""
def _apply_fn(dataset):
- return DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
return _apply_fn
-class UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.Dataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__()
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -144,7 +144,7 @@ def unbatch():
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
if not sparse.any_sparse(dataset.output_classes):
- return UnbatchDataset(dataset)
+ return _UnbatchDataset(dataset)
# NOTE(mrry): We must ensure that any SparseTensors in `dataset`
# are normalized to the rank-1 dense representation, so that the
@@ -170,7 +170,7 @@ def unbatch():
dataset.output_shapes,
dataset.output_classes,
allow_unsafe_cast=True)
- return UnbatchDataset(restructured_dataset)
+ return _UnbatchDataset(restructured_dataset)
return _apply_fn
@@ -298,12 +298,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.Dataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__()
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 5f5513849c..d46d96c461 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -46,17 +46,17 @@ def ignore_errors():
"""
def _apply_fn(dataset):
- return IgnoreErrorsDataset(dataset)
+ return _IgnoreErrorsDataset(dataset)
return _apply_fn
-class IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.Dataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__()
self._input_dataset = input_dataset
def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 4068a2ffa5..348884e9fa 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -55,7 +55,7 @@ def group_by_reducer(key_func, reducer):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByReducerDataset(dataset, key_func, reducer)
+ return _GroupByReducerDataset(dataset, key_func, reducer)
return _apply_fn
@@ -113,8 +113,8 @@ def group_by_window(key_func,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
return _apply_fn
@@ -254,12 +254,12 @@ class _VariantDataset(dataset_ops.Dataset):
return self._output_types
-class GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__()
self._input_dataset = input_dataset
@@ -388,12 +388,12 @@ class GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
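
Reducer is now exported next to group_by_reducer (see the __init__.py hunk above). A hedged sketch of how the pair combines, counting elements per key; the lambdas and key choice are illustrative:

import numpy as np
import tensorflow as tf

reducer = tf.contrib.data.Reducer(
    init_func=lambda key: np.int64(0),             # initial state for each key
    reduce_func=lambda state, element: state + 1,  # fold an element into the state
    finalize_func=lambda state: state)             # emit the per-key count

counts = tf.data.Dataset.range(10).apply(
    tf.contrib.data.group_by_reducer(
        key_func=lambda x: x % 2,  # key_func must return a scalar tf.int64
        reducer=reducer))
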
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 70153ac575..bcc959594a 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
return _apply_fn
-class DirectedInterleaveDataset(dataset_ops.Dataset):
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
"""A substitute for `Dataset.interleave()` on a fixed list of datasets."""
def __init__(self, selector_input, data_inputs):
@@ -236,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None):
selector_input = dataset_ops.Dataset.zip(
(logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
- return DirectedInterleaveDataset(selector_input, datasets)
+ return _DirectedInterleaveDataset(selector_input, datasets)
def choose_from_datasets(datasets, choice_dataset):
@@ -280,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset):
and choice_dataset.output_classes == ops.Tensor):
raise TypeError("`choice_dataset` must be a dataset of scalar "
"`tf.int64` tensors.")
- return DirectedInterleaveDataset(choice_dataset, datasets)
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
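
_DirectedInterleaveDataset stays reachable through the two public wrappers in this hunk. A short sketch of both, assuming the TF 1.x contrib API:

import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors(i).repeat() for i in range(3)]

# Weighted random mixing of the inputs.
mixed = tf.contrib.data.sample_from_datasets(
    datasets, weights=[0.6, 0.3, 0.1], seed=7)

# Deterministic selection driven by a dataset of scalar tf.int64 indices.
choice = tf.data.Dataset.range(3).repeat()
chosen = tf.contrib.data.choose_from_datasets(datasets, choice)
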
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 2ca3805d66..cf89657226 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -39,17 +39,17 @@ def optimize(optimizations=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return OptimizeDataset(dataset, optimizations)
+ return _OptimizeDataset(dataset, optimizations)
return _apply_fn
-class OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__()
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3c82a03df1..97931f75bd 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
@@ -110,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def set_stats_aggregator(stats_aggregator):
"""Set the given stats_aggregator for aggregating the input dataset stats.
@@ -128,6 +131,8 @@ def set_stats_aggregator(stats_aggregator):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def bytes_produced_stats(tag):
"""Records the number of bytes produced by each element of the input dataset.
@@ -150,6 +155,8 @@ def bytes_produced_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
@@ -171,6 +178,8 @@ def latency_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def feature_stats(tag):
"""Records the features stats from `Example` records of the input dataset.
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index bb49604d4d..f228660176 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -37,6 +37,8 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class PrivateThreadPool(object):
"""A stateful resource that represents a private thread pool."""
@@ -82,6 +84,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def override_threadpool(dataset, thread_pool):
"""Returns a new dataset that uses the given thread pool for its operations.
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index 4ce6ddede8..e0ce0a4ef1 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -42,17 +42,17 @@ def unique():
"""
def _apply_fn(dataset):
- return UniqueDataset(dataset)
+ return _UniqueDataset(dataset)
return _apply_fn
-class UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.Dataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__()
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
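
Every transformation touched by this commit follows the same shape: a public factory returns an _apply_fn closure, and the now-private _FooDataset class is only ever constructed through Dataset.apply. A hypothetical sketch of that pattern built from existing public ops; take_every_other and its internals are illustrative, not TensorFlow API:

import tensorflow as tf

def take_every_other():
  """Returns a transformation usable as dataset.apply(take_every_other())."""

  def _apply_fn(dataset):
    # A real implementation would wrap `dataset` in a private
    # _TakeEveryOtherDataset(dataset_ops.Dataset) subclass, mirroring
    # _UniqueDataset above; here it is approximated with public ops.
    return (dataset
            .apply(tf.contrib.data.enumerate_dataset())
            .filter(lambda i, _: tf.equal(i % 2, 0))
            .map(lambda _, x: x))

  return _apply_fn

ds = tf.data.Dataset.range(6).apply(take_every_other())  # yields 0, 2, 4
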