diff options
Diffstat (limited to 'tensorflow/python/data/experimental/ops/optimization.py')
-rw-r--r-- | tensorflow/python/data/experimental/ops/optimization.py | 61 |
1 files changed, 2 insertions, 59 deletions
diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py index 30348ede36..276dde8383 100644 --- a/tensorflow/python/data/experimental/ops/optimization.py +++ b/tensorflow/python/data/experimental/ops/optimization.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_experimental_dataset_ops # A constant that can be used to enable auto-tuning. @@ -58,7 +57,7 @@ def model(): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _ModelDataset(dataset) + return dataset_ops._ModelDataset(dataset) # pylint: disable=protected-access return _apply_fn @@ -78,7 +77,7 @@ def optimize(optimizations=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _OptimizeDataset(dataset, optimizations) + return dataset_ops._OptimizeDataset(dataset, optimizations) # pylint: disable=protected-access return _apply_fn @@ -113,59 +112,3 @@ class _AssertNextDataset(dataset_ops.UnaryDataset): def output_types(self): return self._input_dataset.output_types - -class _ModelDataset(dataset_ops.UnaryDataset): - """A `Dataset` that acts as an identity, and models performance.""" - - def __init__(self, input_dataset): - """See `optimize()` for details.""" - super(_ModelDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_dataset_ops.model_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -class _OptimizeDataset(dataset_ops.UnaryDataset): - """A `Dataset` that acts as an identity, and applies optimizations.""" - - def __init__(self, input_dataset, optimizations): - """See `optimize()` for details.""" - super(_OptimizeDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - if optimizations is None: - optimizations = [] - self._optimizations = ops.convert_to_tensor( - optimizations, dtype=dtypes.string, name="optimizations") - - def _as_variant_tensor(self): - return gen_dataset_ops.optimize_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._optimizations, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types |