aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/record_input_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/record_input_test.py')
-rw-r--r--tensorflow/python/kernel_tests/record_input_test.py14
1 files changed, 5 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 8fec2affa5..a3fc98c20f 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -53,7 +53,6 @@ class RecordInputOpTest(test.TestCase):
def testRecordInputEpochs(self):
files = 100
records_per_file = 100
- batches = 2
with self.test_session() as sess:
self.generateTestData("basic", files, records_per_file)
@@ -64,20 +63,17 @@ class RecordInputOpTest(test.TestCase):
batch_size=1,
shift_ratio=0.33,
seed=10,
- name="record_input",
- batches=batches)
+ name="record_input")
yield_op = records.get_yield_op()
# cycle over 3 epochs and make sure we never duplicate
for _ in range(3):
epoch_set = set()
- for _ in range(int(files * records_per_file / batches)):
- op_list = sess.run(yield_op)
- self.assertTrue(len(op_list) is batches)
- for r in op_list:
- self.assertTrue(r[0] not in epoch_set)
- epoch_set.add(r[0])
+ for _ in range(files * records_per_file):
+ r = sess.run(yield_op)
+ self.assertTrue(r[0] not in epoch_set)
+ epoch_set.add(r[0])
def testDoesNotDeadlock(self):
# Iterate multiple times to cause deadlock if there is a chance it can occur