diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/record_input_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/record_input_test.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py index a3fc98c20f..8fec2affa5 100644 --- a/tensorflow/python/kernel_tests/record_input_test.py +++ b/tensorflow/python/kernel_tests/record_input_test.py @@ -53,6 +53,7 @@ class RecordInputOpTest(test.TestCase): def testRecordInputEpochs(self): files = 100 records_per_file = 100 + batches = 2 with self.test_session() as sess: self.generateTestData("basic", files, records_per_file) @@ -63,17 +64,20 @@ class RecordInputOpTest(test.TestCase): batch_size=1, shift_ratio=0.33, seed=10, - name="record_input") + name="record_input", + batches=batches) yield_op = records.get_yield_op() # cycle over 3 epochs and make sure we never duplicate for _ in range(3): epoch_set = set() - for _ in range(files * records_per_file): - r = sess.run(yield_op) - self.assertTrue(r[0] not in epoch_set) - epoch_set.add(r[0]) + for _ in range(int(files * records_per_file / batches)): + op_list = sess.run(yield_op) + self.assertTrue(len(op_list) is batches) + for r in op_list: + self.assertTrue(r[0] not in epoch_set) + epoch_set.add(r[0]) def testDoesNotDeadlock(self): # Iterate multiple times to cause deadlock if there is a chance it can occur |