aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/optimization.py')
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index cf89657226..018c5115e1 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -18,12 +18,34 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
+# account for indexing) and transformation sequence.
+def assert_next(transformations):
+ """A transformation that asserts which transformations happen next.
+
+ Args:
+ transformations: A `tf.string` vector `tf.Tensor` identifying the
+ transformations that are expected to happen next.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _AssertNextDataset(dataset, transformations)
+
+ return _apply_fn
+
+
def optimize(optimizations=None):
"""A transformation that applies optimizations.
@@ -44,6 +66,37 @@ def optimize(optimizations=None):
return _apply_fn
+class _AssertNextDataset(dataset_ops.Dataset):
+ """A `Dataset` that asserts which transformations happen next."""
+
+ def __init__(self, input_dataset, transformations):
+ """See `assert_next()` for details."""
+ super(_AssertNextDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if transformations is None:
+ raise ValueError("At least one transformation should be specified")
+ self._transformations = ops.convert_to_tensor(
+ transformations, dtype=dtypes.string, name="transformations")
+
+ def _as_variant_tensor(self):
+ return contrib_gen_dataset_ops.assert_next_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._transformations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
class _OptimizeDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""