aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/shuffle_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/shuffle_ops.py')
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py56
1 files changed, 5 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 985d1d87d0..329b34fdfe 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -17,54 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that fuses `shuffle` and `repeat`."""
-
- def __init__(self, input_dataset, buffer_size, count=None, seed=None):
- super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._buffer_size = ops.convert_to_tensor(
- buffer_size, dtype=dtypes.int64, name="buffer_size")
- if count is None:
- self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
- else:
- self._count = ops.convert_to_tensor(
- count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.shuffle_and_repeat_dataset(
- input_resource,
- buffer_size=self._buffer_size,
- count=self._count,
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.shuffle_and_repeat(...)`.")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.
@@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset): # pylint: disable=missing-docstring
- return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
-
- return _apply_fn
+ return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)