diff options
Diffstat (limited to 'tensorflow/contrib/data/python/ops/shuffle_ops.py')
-rw-r--r-- | tensorflow/contrib/data/python/ops/shuffle_ops.py | 56 |
1 files changed, 5 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 985d1d87d0..329b34fdfe 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -17,54 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import random_seed -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops - - -class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset): - """A `Dataset` that fuses `shuffle` and `repeat`.""" - - def __init__(self, input_dataset, buffer_size, count=None, seed=None): - super(_ShuffleAndRepeatDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._buffer_size = ops.convert_to_tensor( - buffer_size, dtype=dtypes.int64, name="buffer_size") - if count is None: - self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") - else: - self._count = ops.convert_to_tensor( - count, dtype=dtypes.int64, name="count") - self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.shuffle_and_repeat_dataset( - input_resource, - buffer_size=self._buffer_size, - count=self._count, - seed=self._seed, - seed2=self._seed2, - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.shuffle_and_repeat(...)`.") def shuffle_and_repeat(buffer_size, count=None, seed=None): """Shuffles and repeats a Dataset returning a new permutation for each epoch. @@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): # pylint: disable=missing-docstring - return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) - - return _apply_fn + return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed) |