aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/optimization.py')
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 73840452df..3eb172acd5 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -84,12 +84,12 @@ def optimize(optimizations=None):
return _apply_fn
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that asserts which transformations happen next."""
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__()
+ super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
@@ -115,12 +115,12 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _ModelDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and models performance."""
def __init__(self, input_dataset):
"""See `optimize()` for details."""
- super(_ModelDataset, self).__init__()
+ super(_ModelDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
@@ -141,12 +141,12 @@ class _ModelDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []