aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/experimental/ops/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/experimental/ops/optimization.py')
-rw-r--r--tensorflow/python/data/experimental/ops/optimization.py61
1 files changed, 2 insertions, 59 deletions
diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
index 30348ede36..276dde8383 100644
--- a/tensorflow/python/data/experimental/ops/optimization.py
+++ b/tensorflow/python/data/experimental/ops/optimization.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
@@ -58,7 +57,7 @@ def model():
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return _ModelDataset(dataset)
+ return dataset_ops._ModelDataset(dataset) # pylint: disable=protected-access
return _apply_fn
@@ -78,7 +77,7 @@ def optimize(optimizations=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return _OptimizeDataset(dataset, optimizations)
+ return dataset_ops._OptimizeDataset(dataset, optimizations) # pylint: disable=protected-access
return _apply_fn
@@ -113,59 +112,3 @@ class _AssertNextDataset(dataset_ops.UnaryDataset):
def output_types(self):
return self._input_dataset.output_types
-
-class _ModelDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and models performance."""
-
- def __init__(self, input_dataset):
- """See `optimize()` for details."""
- super(_ModelDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.model_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _OptimizeDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and applies optimizations."""
-
- def __init__(self, input_dataset, optimizations):
- """See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if optimizations is None:
- optimizations = []
- self._optimizations = ops.convert_to_tensor(
- optimizations, dtype=dtypes.string, name="optimizations")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.optimize_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._optimizations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types