Diffstat (limited to 'tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py')
-rw-r--r--  tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py  66
 1 file changed, 66 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index a46e4b00f9..10ea883e1f 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.slim.python.slim.data import test_utils
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import variables
@@ -74,6 +75,54 @@ class ParallelReaderTest(test.TestCase):
     self.assertGreater(count2, 0)
     self.assertEquals(count0 + count1 + count2, num_reads)
 
+  def _verify_read_up_to_out(self, shared_queue):
+    with self.test_session():
+      num_files = 3
+      num_records_per_file = 7
+      tfrecord_paths = test_utils.create_tfrecord_files(
+          self.get_temp_dir(),
+          num_files=num_files,
+          num_records_per_file=num_records_per_file)
+
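+    # Five readers dequeue filenames from the same shared output queue.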
+    p_reader = parallel_reader.ParallelReader(
+        io_ops.TFRecordReader, shared_queue, num_readers=5)
+
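+    # read_up_to returns batches of at most 4 (key, value) pairs per run call;
+    # with a single epoch the final batch may hold fewer than 4 records.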
+    data_files = parallel_reader.get_data_files(tfrecord_paths)
+    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
+    key, value = p_reader.read_up_to(filename_queue, 4)
+
+    count0 = 0
+    count1 = 0
+    count2 = 0
+    all_keys_count = 0
+    all_values_count = 0
+
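+    # Drain the pipeline until the one-epoch filename queue is exhausted.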
+    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
+    with sv.prepare_or_wait_for_session() as sess:
+      sv.start_queue_runners(sess)
+      while True:
+        try:
+          current_keys, current_values = sess.run([key, value])
+          self.assertEquals(len(current_keys), len(current_values))
+          all_keys_count += len(current_keys)
+          all_values_count += len(current_values)
+          for current_key in current_keys:
+            if '0-of-3' in str(current_key):
+              count0 += 1
+            if '1-of-3' in str(current_key):
+              count1 += 1
+            if '2-of-3' in str(current_key):
+              count2 += 1
+        except errors_impl.OutOfRangeError:
+          break
+
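+    # Every record from each of the three files should be read exactly once.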
+    self.assertEquals(count0, num_records_per_file)
+    self.assertEquals(count1, num_records_per_file)
+    self.assertEquals(count2, num_records_per_file)
+    self.assertEquals(all_keys_count, num_files * num_records_per_file)
+    self.assertEquals(all_values_count, all_keys_count)
+    self.assertEquals(count0 + count1 + count2, all_keys_count)
+
   def testRandomShuffleQueue(self):
     shared_queue = data_flow_ops.RandomShuffleQueue(
         capacity=256,
@@ -86,6 +135,23 @@ class ParallelReaderTest(test.TestCase):
         capacity=256, dtypes=[dtypes_lib.string, dtypes_lib.string])
     self._verify_all_data_sources_read(shared_queue)
 
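+  # read_up_to relies on the queue's dequeue_up_to, which requires fully
+  # defined element shapes, so both queues below specify scalar shapes
+  # for the key and value strings.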
+  def testReadUpToFromRandomShuffleQueue(self):
+    shared_queue = data_flow_ops.RandomShuffleQueue(
+        capacity=55,
+        min_after_dequeue=28,
+        dtypes=[dtypes_lib.string, dtypes_lib.string],
+        shapes=[tensor_shape.scalar(),
+                tensor_shape.scalar()])
+    self._verify_read_up_to_out(shared_queue)
+
+  def testReadUpToFromFIFOQueue(self):
+    shared_queue = data_flow_ops.FIFOQueue(
+        capacity=99,
+        dtypes=[dtypes_lib.string, dtypes_lib.string],
+        shapes=[tensor_shape.scalar(),
+                tensor_shape.scalar()])
+    self._verify_read_up_to_out(shared_queue)
+
 
 class ParallelReadTest(test.TestCase):