aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/record_input_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/record_input_test.py')
-rw-r--r--tensorflow/python/kernel_tests/record_input_test.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index a3fc98c20f..8fec2affa5 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -53,6 +53,7 @@ class RecordInputOpTest(test.TestCase):
def testRecordInputEpochs(self):
files = 100
records_per_file = 100
+ batches = 2
with self.test_session() as sess:
self.generateTestData("basic", files, records_per_file)
@@ -63,17 +64,20 @@ class RecordInputOpTest(test.TestCase):
batch_size=1,
shift_ratio=0.33,
seed=10,
- name="record_input")
+ name="record_input",
+ batches=batches)
yield_op = records.get_yield_op()
# cycle over 3 epochs and make sure we never duplicate
for _ in range(3):
epoch_set = set()
- for _ in range(files * records_per_file):
- r = sess.run(yield_op)
- self.assertTrue(r[0] not in epoch_set)
- epoch_set.add(r[0])
+ for _ in range(int(files * records_per_file / batches)):
+ op_list = sess.run(yield_op)
+ self.assertTrue(len(op_list) is batches)
+ for r in op_list:
+ self.assertTrue(r[0] not in epoch_set)
+ epoch_set.add(r[0])
def testDoesNotDeadlock(self):
# Iterate multiple times to cause deadlock if there is a chance it can occur