aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/input_test.py
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python/training/input_test.py
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/python/training/input_test.py')
-rw-r--r--  tensorflow/python/training/input_test.py | 477
1 file changed, 477 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
new file mode 100644
index 0000000000..fe8c195e77
--- /dev/null
+++ b/tensorflow/python/training/input_test.py
@@ -0,0 +1,477 @@
+"""Tests for training.input."""
+
+import os
+import itertools
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
class MatchFilenamesOnceTest(tf.test.TestCase):
  """Tests for tf.train.match_filenames_once."""

  def test(self):
    """Glob '*', single-char '?', and literal patterns all resolve."""
    temp_dir = self.get_temp_dir()
    # Whatever already exists in the temp dir, plus three known files.
    filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
    additional = [os.path.join(self.get_temp_dir(), "match_filenames.%d" % i)
                  for i in range(3)]
    for name in additional:
      # Context manager closes the handle promptly; the original
      # `open(name, "w").write(...)` leaked it until GC.
      with open(name, "w") as f:
        f.write("Some contents")
    filenames += additional
    with self.test_session():
      star = tf.train.match_filenames_once(
          os.path.join(self.get_temp_dir(), "*"))
      question = tf.train.match_filenames_once(
          os.path.join(self.get_temp_dir(), "match_filenames.?"))
      one = tf.train.match_filenames_once(additional[1])
      tf.initialize_all_variables().run()
      # Order is not guaranteed, so compare as multisets.
      self.assertItemsEqual(filenames, star.eval())
      self.assertItemsEqual(additional, question.eval())
      self.assertItemsEqual([additional[1]], one.eval())
+
+
class LimitEpochsTest(tf.test.TestCase):
  """Tests for tf.train.limit_epochs."""

  def testNoLimit(self):
    """Without num_epochs the tensor can be evaluated indefinitely."""
    with self.test_session():
      seven = tf.constant(7)
      seven_forever = tf.train.limit_epochs(seven)
      tf.initialize_all_variables().run()
      for _ in range(100):  # index unused; '_' makes that explicit
        self.assertEqual(7, seven_forever.eval())

  def testLimit(self):
    """With num_epochs=2 the third evaluation raises OutOfRangeError."""
    with self.test_session():
      love_me = tf.constant("Love Me")
      love_me_two_times = tf.train.limit_epochs(love_me, num_epochs=2)
      tf.initialize_all_variables().run()
      self.assertEqual("Love Me", love_me_two_times.eval())
      self.assertEqual("Love Me", love_me_two_times.eval())
      with self.assertRaises(tf.errors.OutOfRangeError):
        love_me_two_times.eval()
+
+
class StringInputProducerTest(tf.test.TestCase):
  """Tests for tf.train.string_input_producer."""

  def testNoShuffle(self):
    """Unshuffled producer yields the input strings in order, repeated."""
    with self.test_session():
      strings = ["to", "be", "or", "not", "to", "be"]
      num_epochs = 3
      queue = tf.train.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(strings) * num_epochs)
      dequeue = queue.dequeue()
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      output = dequeue_many.eval()
      self.assertAllEqual(strings * num_epochs, output)

      # Reached the limit: one more dequeue must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    """Shuffling permutes within an epoch, roughly uniformly over orders."""
    with self.test_session():
      strings = ["a", "b", "c"]
      num_epochs = 600
      queue = tf.train.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=True, seed=271828)
      dequeue_many = queue.dequeue_many(len(strings))
      dequeue = queue.dequeue()
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # Validate that we only shuffle the strings within an epoch and
      # count how often each possible order appears.
      expected = ["abc", "acb", "bac", "bca", "cab", "cba"]
      # dict.fromkeys zero-initializes all counters in one call.
      frequency = dict.fromkeys(expected, 0)
      for _ in range(num_epochs):
        output = dequeue_many.eval()
        key = "".join(output)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible
      # orders, within a generous 40% margin.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf.logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()
+
+
class RangeInputProducerTest(tf.test.TestCase):
  """Tests for tf.train.range_input_producer."""

  def testNoShuffle(self):
    """Unshuffled producer yields 0..range_size-1 in order, repeated."""
    with self.test_session():
      num_epochs = 3
      range_size = 5
      queue = tf.train.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(range_size * num_epochs)
      dequeue = queue.dequeue()
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      # list(range(...)) so sequence repetition also works on Python 3,
      # where range() is no longer a list.
      output = dequeue_many.eval()
      self.assertAllEqual(list(range(range_size)) * num_epochs, output)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    """Shuffling permutes within an epoch, roughly uniformly over orders."""
    with self.test_session():
      num_epochs = 200
      range_size = 2
      queue = tf.train.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
      dequeue_many = queue.dequeue_many(range_size)
      dequeue = queue.dequeue()
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # Validate that we only shuffle the integers within an epoch and
      # count how often each possible order appears.
      expected = [12, 21]  # two-digit encodings of orders (0,1) / (1,0)
      frequency = dict.fromkeys(expected, 0)
      for _ in range(num_epochs):
        output = dequeue_many.eval()
        # Encode the observed order as a two-digit number, e.g. (1, 0) -> 21.
        key = 10 * (output[0] + 1) + (output[1] + 1)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf.logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()
+
+
class SliceInputProducerTest(tf.test.TestCase):
  """Tests for tf.train.slice_input_producer."""

  def testNoShuffle(self):
    """Unshuffled slices come out in input order, repeated per epoch."""
    with self.test_session() as sess:
      num_epochs = 3
      source_strings = ["Alpha", "Beta", "Delta", "Gamma"]
      source_ints = [2, 3, 5, 7]
      slices = tf.train.slice_input_producer(
          [source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      num_items = len(source_strings) * num_epochs
      output = [sess.run(slices) for _ in range(num_items)]
      out_strings, out_ints = zip(*output)
      self.assertAllEqual(source_strings * num_epochs, out_strings)
      self.assertAllEqual(source_ints * num_epochs, out_ints)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()

  def testShuffle(self):
    """Shuffling keeps (string, int) pairs aligned and permutes per epoch."""
    with self.test_session() as sess:
      num_epochs = 1200
      source_strings = ["A", "B", "D", "G"]
      source_ints = [7, 3, 5, 2]
      slices = tf.train.slice_input_producer(
          [source_strings, source_ints], num_epochs=num_epochs, shuffle=True,
          seed=161803)
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # Validate that we only shuffle within an epoch and count how often
      # each possible order appears.
      expected = [",".join(x) for x in
                  itertools.permutations(["A7", "B3", "D5", "G2"])]
      frequency = dict.fromkeys(expected, 0)
      for _ in range(num_epochs):
        output = [sess.run(slices) for _ in range(len(source_strings))]
        # NOTE(review): `s + str(i)` assumes the string slice evaluates to
        # a Python 2 str; under Python 3 sess.run would return bytes —
        # confirm before porting.
        key = ",".join([s + str(i) for s, i in output])
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf.logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()
+
+
class BatchTest(tf.test.TestCase):
  """Tests for tf.train.batch with one and with several dequeue threads."""

  def testOneThread(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      total = num_batches * batch_size
      # An int64 counter that raises OutOfRangeError after `total` increments.
      examples = tf.Variable(tf.constant(0, dtype=tf.int64))
      counter = examples.count_up_to(total)
      batch_op = tf.train.batch([counter, "string"], batch_size=batch_size)
      tf.initialize_all_variables().run()
      runner_threads = tf.train.start_queue_runners()

      # A single thread preserves ordering, so every batch is a
      # contiguous range of counts.
      for batch_index in range(num_batches):
        counts, labels = sess.run(batch_op)
        start = batch_index * batch_size
        self.assertAllEqual(counts, range(start, start + batch_size))
        self.assertAllEqual(labels, ["string"] * batch_size)

      # The counter is exhausted: one more dequeue must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(batch_op)
      for thread in runner_threads:
        thread.join()

  def testManyThreads(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      total = num_batches * batch_size
      examples = tf.Variable(tf.constant(0, dtype=tf.int64))
      counter = examples.count_up_to(total)
      batch_op = tf.train.batch([counter, "string"], batch_size=batch_size,
                                num_threads=4)
      tf.initialize_all_variables().run()
      runner_threads = tf.train.start_queue_runners()

      # With four threads the order is arbitrary; collect everything and
      # compare the multiset of counts at the end.
      all_counts = []
      for batch_index in range(num_batches):
        counts, labels = sess.run(batch_op)
        tf.logging.info("Batch %d: %s", batch_index, counts)
        self.assertEqual(len(counts), batch_size)
        all_counts.extend(counts)
        self.assertAllEqual(labels, ["string"] * batch_size)
      self.assertItemsEqual(all_counts, range(total))

      # The counter is exhausted: one more dequeue must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(batch_op)
      for thread in runner_threads:
        thread.join()
+
+
class BatchJoinTest(tf.test.TestCase):
  """Tests for tf.train.batch_join."""

  def testTwoThreads(self):
    """Two producer threads are interleaved; per-thread order is kept."""
    with self.test_session() as sess:
      # Two threads, the first generates (0..34, "a").
      num_a = 35
      zero64 = tf.constant(0, dtype=tf.int64)
      examples = tf.Variable(zero64)
      counter = examples.count_up_to(num_a)

      # The second generates (99, "b") 45 times and then stops.
      num_b = 45
      ninety_nine = tf.train.limit_epochs(
          tf.constant(99, dtype=tf.int64), num_b)

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
                                    batch_size=batch_size)
      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      # Floor division keeps num_batches an int under Python 3; plain "/"
      # would yield a float and break range().
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        tf.logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[1]), batch_size)
        # Distinct comprehension variable avoids shadowing (and, under
        # Python 2's leaking comprehension scope, clobbering) the loop
        # index `i`.
        which_a = [j for j, s in enumerate(results[1]) if s == "a"]
        which_b = [j for j, s in enumerate(results[1]) if s == "b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][j] for j in which_a])
        seen_b += len(which_b)
        # Every "b" row carries the constant 99.
        self.assertAllEqual([99] * len(which_b),
                            [results[0][j] for j in which_b])

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Verify the order of results from "a" were preserved.
      self.assertAllEqual(all_a, range(num_a))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()
+
+
class ShuffleBatchTest(tf.test.TestCase):
  """Tests for tf.train.shuffle_batch with one and with several threads."""

  def testOneThread(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      total = num_batches * batch_size
      examples = tf.Variable(tf.constant(0, dtype=tf.int64))
      counter = examples.count_up_to(total)
      shuffled = tf.train.shuffle_batch(
          [counter, "string"], batch_size=batch_size, capacity=32,
          min_after_dequeue=16, seed=141421)
      tf.initialize_all_variables().run()
      runner_threads = tf.train.start_queue_runners()

      all_counts = []
      for _ in range(num_batches):
        counts, labels = sess.run(shuffled)
        self.assertEqual(len(counts), batch_size)
        all_counts.extend(counts)
        self.assertAllEqual(labels, ["string"] * batch_size)
      # The output must be a permutation of 0..total-1, but not the
      # identity progression: successive deltas must not all match.
      deltas = [b - a for a, b in zip(all_counts, all_counts[1:])]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(total))

      # The counter is exhausted, so the next dequeue must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(shuffled)
      for thread in runner_threads:
        thread.join()

  def testManyThreads(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      total = num_batches * batch_size
      examples = tf.Variable(tf.constant(0, dtype=tf.int64))
      counter = examples.count_up_to(total)
      shuffled = tf.train.shuffle_batch(
          [counter, "string"], batch_size=batch_size, capacity=32,
          min_after_dequeue=16, seed=173205, num_threads=4)
      tf.initialize_all_variables().run()
      runner_threads = tf.train.start_queue_runners()

      all_counts = []
      for batch_index in range(num_batches):
        counts, labels = sess.run(shuffled)
        tf.logging.info("Batch %d: %s", batch_index, counts)
        self.assertEqual(len(counts), batch_size)
        all_counts.extend(counts)
        self.assertAllEqual(labels, ["string"] * batch_size)
      # Scrambled output that still covers every expected number exactly
      # once.
      deltas = [b - a for a, b in zip(all_counts, all_counts[1:])]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(total))

      # The counter is exhausted, so the next dequeue must fail.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(shuffled)
      for thread in runner_threads:
        thread.join()
+
+
class ShuffleBatchJoinTest(tf.test.TestCase):
  """Tests for tf.train.shuffle_batch_join."""

  def testTwoThreads(self):
    """Two producer threads are mixed; output from "a" is shuffled."""
    with self.test_session() as sess:
      # Two threads, the first generates (0..24, "a").
      num_a = 25
      zero64 = tf.constant(0, dtype=tf.int64)
      examples = tf.Variable(zero64)
      counter = examples.count_up_to(num_a)

      # The second generates (99, "b") 35 times and then stops.
      num_b = 35
      ninety_nine = tf.train.limit_epochs(
          tf.constant(99, dtype=tf.int64), num_b)

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      batched = tf.train.shuffle_batch_join(
          [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
          capacity=32, min_after_dequeue=16, seed=223607)

      tf.initialize_all_variables().run()
      threads = tf.train.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      # Floor division keeps num_batches an int under Python 3; plain "/"
      # would yield a float and break range().
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        tf.logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[1]), batch_size)
        # Distinct comprehension variable avoids shadowing (and, under
        # Python 2's leaking comprehension scope, clobbering) the loop
        # index `i`.
        which_a = [j for j, s in enumerate(results[1]) if s == "a"]
        which_b = [j for j, s in enumerate(results[1]) if s == "b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][j] for j in which_a])
        seen_b += len(which_b)
        # Every "b" row carries the constant 99.
        self.assertAllEqual([99] * len(which_b),
                            [results[0][j] for j in which_b])

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Saw all the items from "a", but scrambled (deltas not constant).
      self.assertItemsEqual(all_a, range(num_a))
      deltas = [all_a[i + 1] - all_a[i]
                for i in range(len(all_a) - 1)]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()
+
+
# Standard TensorFlow test entry point: discovers and runs the TestCase
# classes defined above.
if __name__ == "__main__":
  tf.test.main()