aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/summary_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/summary_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/summary_ops_test.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
new file mode 100644
index 0000000000..13e5021ccc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -0,0 +1,83 @@
+"""Tests for summary ops."""
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+class SummaryOpsTest(tf.test.TestCase):
+
+ def _AsSummary(self, s):
+ summ = tf.Summary()
+ summ.ParseFromString(s)
+ return summ
+
+ def testScalarSummary(self):
+ with self.test_session() as sess:
+ const = tf.constant([10.0, 20.0])
+ summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm")
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ self.assertProtoEquals("""
+ value { tag: "c1" simple_value: 10.0 }
+ value { tag: "c2" simple_value: 20.0 }
+ """, self._AsSummary(value))
+
+ def testScalarSummaryDefaultName(self):
+ with self.test_session() as sess:
+ const = tf.constant([10.0, 20.0])
+ summ = tf.scalar_summary(["c1", "c2"], const)
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ self.assertProtoEquals("""
+ value { tag: "c1" simple_value: 10.0 }
+ value { tag: "c2" simple_value: 20.0 }
+ """, self._AsSummary(value))
+
+ def testMergeSummary(self):
+ with self.test_session() as sess:
+ const = tf.constant(10.0)
+ summ1 = tf.histogram_summary("h", const, name="histo")
+ summ2 = tf.scalar_summary("c", const, name="summ")
+ merge = tf.merge_summary([summ1, summ2])
+ value = sess.run(merge)
+ self.assertEqual([], merge.get_shape())
+ self.assertProtoEquals("""
+ value {
+ tag: "h"
+ histo {
+ min: 10.0
+ max: 10.0
+ num: 1.0
+ sum: 10.0
+ sum_squares: 100.0
+ bucket_limit: 9.93809490288
+ bucket_limit: 10.9319043932
+ bucket_limit: 1.79769313486e+308
+ bucket: 0.0
+ bucket: 1.0
+ bucket: 0.0
+ }
+ }
+ value { tag: "c" simple_value: 10.0 }
+ """, self._AsSummary(value))
+
+ def testMergeAllSummaries(self):
+ with tf.Graph().as_default():
+ const = tf.constant(10.0)
+ summ1 = tf.histogram_summary("h", const, name="histo")
+ summ2 = tf.scalar_summary("o", const, name="oops",
+ collections=["foo_key"])
+ summ3 = tf.scalar_summary("c", const, name="summ")
+ merge = tf.merge_all_summaries()
+ self.assertEqual("MergeSummary", merge.op.type)
+ self.assertEqual(2, len(merge.op.inputs))
+ self.assertEqual(summ1, merge.op.inputs[0])
+ self.assertEqual(summ3, merge.op.inputs[1])
+ merge = tf.merge_all_summaries("foo_key")
+ self.assertEqual("MergeSummary", merge.op.type)
+ self.assertEqual(1, len(merge.op.inputs))
+ self.assertEqual(summ2, merge.op.inputs[0])
+ self.assertTrue(tf.merge_all_summaries("bar_key") is None)
+
+
+if __name__ == "__main__":
+ tf.test.main()