aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image/cifar10/cifar10_input_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_input_test.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10_input_test.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py
new file mode 100644
index 0000000000..d43f5aedcf
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_input_test.py
@@ -0,0 +1,49 @@
+"""Tests for cifar10 input."""
+
+import os
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10_input
+
+
+class CIFAR10InputTest(tf.test.TestCase):
+
+ def _record(self, label, red, green, blue):
+ image_size = 32 * 32
+ record = "%s%s%s%s" % (chr(label), chr(red) * image_size,
+ chr(green) * image_size, chr(blue) * image_size)
+ expected = [[[red, green, blue]] * 32] * 32
+ return record, expected
+
+ def testSimple(self):
+ labels = [9, 3, 0]
+ records = [self._record(labels[0], 0, 128, 255),
+ self._record(labels[1], 255, 0, 1),
+ self._record(labels[2], 254, 255, 0)]
+ contents = "".join([record for record, _ in records])
+ expected = [expected for _, expected in records]
+ filename = os.path.join(self.get_temp_dir(), "cifar")
+ open(filename, "w").write(contents)
+
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(99, [tf.string], shapes=())
+ q.enqueue([filename]).run()
+ q.close().run()
+ result = cifar10_input.read_cifar10(q)
+
+ for i in range(3):
+ key, label, uint8image = sess.run([
+ result.key, result.label, result.uint8image])
+ self.assertEqual("%s:%d" % (filename, i), key)
+ self.assertEqual(labels[i], label)
+ self.assertAllEqual(expected[i], uint8image)
+
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run([result.key, result.uint8image])
+
+
+if __name__ == "__main__":
+ tf.test.main()