diff options
Diffstat (limited to 'tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py')
-rw-r--r-- | tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py index a46e4b00f9..10ea883e1f 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.slim.python.slim.data import test_utils from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import variables @@ -74,6 +75,54 @@ class ParallelReaderTest(test.TestCase): self.assertGreater(count2, 0) self.assertEquals(count0 + count1 + count2, num_reads) + def _verify_read_up_to_out(self, shared_queue): + with self.test_session(): + num_files = 3 + num_records_per_file = 7 + tfrecord_paths = test_utils.create_tfrecord_files( + self.get_temp_dir(), + num_files=num_files, + num_records_per_file=num_records_per_file) + + p_reader = parallel_reader.ParallelReader( + io_ops.TFRecordReader, shared_queue, num_readers=5) + + data_files = parallel_reader.get_data_files(tfrecord_paths) + filename_queue = input_lib.string_input_producer(data_files, num_epochs=1) + key, value = p_reader.read_up_to(filename_queue, 4) + + count0 = 0 + count1 = 0 + count2 = 0 + all_keys_count = 0 + all_values_count = 0 + + sv = supervisor.Supervisor(logdir=self.get_temp_dir()) + with sv.prepare_or_wait_for_session() as sess: + sv.start_queue_runners(sess) + while True: + try: + current_keys, current_values = sess.run([key, value]) + self.assertEquals(len(current_keys), len(current_values)) + all_keys_count += len(current_keys) + all_values_count += len(current_values) + for current_key in current_keys: + if '0-of-3' in str(current_key): + count0 += 1 + if '1-of-3' in str(current_key): + count1 += 1 + if '2-of-3' in str(current_key): + count2 += 1 + except errors_impl.OutOfRangeError: + break + + self.assertEquals(count0, num_records_per_file) + self.assertEquals(count1, num_records_per_file) + self.assertEquals(count2, num_records_per_file) + self.assertEquals(all_keys_count, num_files * num_records_per_file) + self.assertEquals(all_values_count, all_keys_count) + self.assertEquals(count0 + count1 + count2, all_keys_count) + def testRandomShuffleQueue(self): shared_queue = data_flow_ops.RandomShuffleQueue( capacity=256, @@ -86,6 +135,23 @@ class ParallelReaderTest(test.TestCase): capacity=256, dtypes=[dtypes_lib.string, dtypes_lib.string]) self._verify_all_data_sources_read(shared_queue) + def testReadUpToFromRandomShuffleQueue(self): + shared_queue = data_flow_ops.RandomShuffleQueue( + capacity=55, + min_after_dequeue=28, + dtypes=[dtypes_lib.string, dtypes_lib.string], + shapes=[tensor_shape.scalar(), + tensor_shape.scalar()]) + self._verify_read_up_to_out(shared_queue) + + def testReadUpToFromFIFOQueue(self): + shared_queue = data_flow_ops.FIFOQueue( + capacity=99, + dtypes=[dtypes_lib.string, dtypes_lib.string], + shapes=[tensor_shape.scalar(), + tensor_shape.scalar()]) + self._verify_read_up_to_out(shared_queue) + class ParallelReadTest(test.TestCase): |