aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/summary_image_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/summary_image_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/summary_image_op_test.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/summary_image_op_test.py b/tensorflow/python/kernel_tests/summary_image_op_test.py
new file mode 100644
index 0000000000..dfdb2c8938
--- /dev/null
+++ b/tensorflow/python/kernel_tests/summary_image_op_test.py
@@ -0,0 +1,63 @@
+"""Tests for summary image op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import image_ops
+
+
+class SummaryImageOpTest(tf.test.TestCase):
+
+ def _AsSummary(self, s):
+ summ = tf.Summary()
+ summ.ParseFromString(s)
+ return summ
+
+ def testImageSummary(self):
+ np.random.seed(7)
+ with self.test_session() as sess:
+ for depth in 1, 3, 4:
+ shape = (4, 5, 7) + (depth,)
+ bad_color = [255, 0, 0, 255][:depth]
+ for positive in False, True:
+ # Build a mostly random image with one nan
+ const = np.random.randn(*shape)
+ const[0, 1, 2] = 0 # Make the nan entry not the max
+ if positive:
+ const = 1 + np.maximum(const, 0)
+ scale = 255 / const.reshape(4, -1).max(axis=1)
+ offset = 0
+ else:
+ scale = 127 / np.abs(const.reshape(4, -1)).max(axis=1)
+ offset = 128
+ adjusted = np.floor(scale[:, None, None, None] * const + offset)
+ const[0, 1, 2, depth / 2] = np.nan
+
+ # Summarize
+ summ = tf.image_summary("img", const)
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ image_summ = self._AsSummary(value)
+
+ # Decode the first image and check consistency
+ image = image_ops.decode_png(
+ image_summ.value[0].image.encoded_image_string).eval()
+ self.assertAllEqual(image[1, 2], bad_color)
+ image[1, 2] = adjusted[0, 1, 2]
+ self.assertAllClose(image, adjusted[0])
+
+ # Check the rest of the proto
+ # Only the first 3 images are returned.
+ for v in image_summ.value:
+ v.image.ClearField("encoded_image_string")
+ expected = '\n'.join("""
+ value {
+ tag: "img/image/%d"
+ image { height: %d width: %d colorspace: %d }
+ }""" % ((i,) + shape[1:]) for i in xrange(3))
+ self.assertProtoEquals(expected, image_summ)
+
+
+if __name__ == "__main__":
+ tf.test.main()