aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-09-23 23:59:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 00:03:46 -0700
commitfe4ae644e55ac776b310160f363bcf71a221ee04 (patch)
treed57623e169908c258156cb2fe6f46103852ca3ad /tensorflow/contrib/distribute
parent03f219c4dcd68127eb417358c9c7216d7a273418 (diff)
Remove dependency on contrib dataset ops.
PiperOrigin-RevId: 214219282
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py20
1 files changed, 2 insertions, 18 deletions
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index c5acb7ced4..559de97bb1 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -126,20 +124,6 @@ class AutoShardDatasetTest(test.TestCase):
# contain records in order of files.
self._verifySimpleShardingOutput(dataset, self._record)
- def testParallelInterleave(self):
- dataset = dataset_ops.Dataset.from_tensor_slices(
- self._createTFRecordFiles())
- dataset = dataset.apply(interleave_ops.parallel_interleave(
- readers.TFRecordDataset,
- cycle_length=4,
- block_length=self._num_records))
- dataset = input_ops.auto_shard_dataset(
- dataset, self._num_shards, self._shard_index)
-
- # Since block_length == num records in each file, the output will still
- # contain records in order of files.
- self._verifySimpleShardingOutput(dataset, self._record)
-
def testListfiles(self):
filenames = self._createTFRecordFiles()
file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
@@ -171,8 +155,8 @@ class AutoShardDatasetTest(test.TestCase):
dataset = dataset.prefetch(buffer_size=batch_size)
dataset = dataset.shuffle(2 * self._num_files * self._num_records)
dataset = dataset.repeat(num_epochs)
- dataset = dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size=batch_size))
+ dataset = dataset.map(lambda x: x)
+ dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=None)
# Auto shard.