diff options
Diffstat (limited to 'tensorflow/contrib/data/python/ops/optimization.py')
-rw-r--r-- | tensorflow/contrib/data/python/ops/optimization.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index cf89657226..018c5115e1 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -18,12 +18,34 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to +# account for indexing) and transformation sequence. +def assert_next(transformations): + """A transformation that asserts which transformations happen next. + + Args: + transformations: A `tf.string` vector `tf.Tensor` identifying the + transformations that are expected to happen next. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _AssertNextDataset(dataset, transformations) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -44,6 +66,37 @@ def optimize(optimizations=None): return _apply_fn +class _AssertNextDataset(dataset_ops.Dataset): + """A `Dataset` that asserts which transformations happen next.""" + + def __init__(self, input_dataset, transformations): + """See `assert_next()` for details.""" + super(_AssertNextDataset, self).__init__() + self._input_dataset = input_dataset + if transformations is None: + raise ValueError("At least one transformation should be specified") + self._transformations = ops.convert_to_tensor( + transformations, dtype=dtypes.string, name="transformations") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.assert_next_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" |