diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/summary_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/summary_ops_test.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py new file mode 100644 index 0000000000..13e5021ccc --- /dev/null +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -0,0 +1,83 @@ +"""Tests for summary ops.""" +import tensorflow.python.platform + +import tensorflow as tf + +class SummaryOpsTest(tf.test.TestCase): + + def _AsSummary(self, s): + summ = tf.Summary() + summ.ParseFromString(s) + return summ + + def testScalarSummary(self): + with self.test_session() as sess: + const = tf.constant([10.0, 20.0]) + summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm") + value = sess.run(summ) + self.assertEqual([], summ.get_shape()) + self.assertProtoEquals(""" + value { tag: "c1" simple_value: 10.0 } + value { tag: "c2" simple_value: 20.0 } + """, self._AsSummary(value)) + + def testScalarSummaryDefaultName(self): + with self.test_session() as sess: + const = tf.constant([10.0, 20.0]) + summ = tf.scalar_summary(["c1", "c2"], const) + value = sess.run(summ) + self.assertEqual([], summ.get_shape()) + self.assertProtoEquals(""" + value { tag: "c1" simple_value: 10.0 } + value { tag: "c2" simple_value: 20.0 } + """, self._AsSummary(value)) + + def testMergeSummary(self): + with self.test_session() as sess: + const = tf.constant(10.0) + summ1 = tf.histogram_summary("h", const, name="histo") + summ2 = tf.scalar_summary("c", const, name="summ") + merge = tf.merge_summary([summ1, summ2]) + value = sess.run(merge) + self.assertEqual([], merge.get_shape()) + self.assertProtoEquals(""" + value { + tag: "h" + histo { + min: 10.0 + max: 10.0 + num: 1.0 + sum: 10.0 + sum_squares: 100.0 + bucket_limit: 9.93809490288 + bucket_limit: 10.9319043932 + bucket_limit: 1.79769313486e+308 + bucket: 0.0 + bucket: 1.0 + bucket: 0.0 + } + } + value { tag: "c" simple_value: 10.0 } + """, self._AsSummary(value)) + + def testMergeAllSummaries(self): + with tf.Graph().as_default(): + const = tf.constant(10.0) + summ1 = tf.histogram_summary("h", const, name="histo") + summ2 = tf.scalar_summary("o", const, name="oops", + collections=["foo_key"]) + summ3 = tf.scalar_summary("c", const, name="summ") + merge = tf.merge_all_summaries() + self.assertEqual("MergeSummary", merge.op.type) + self.assertEqual(2, len(merge.op.inputs)) + self.assertEqual(summ1, merge.op.inputs[0]) + self.assertEqual(summ3, merge.op.inputs[1]) + merge = tf.merge_all_summaries("foo_key") + self.assertEqual("MergeSummary", merge.op.type) + self.assertEqual(1, len(merge.op.inputs)) + self.assertEqual(summ2, merge.op.inputs[0]) + self.assertTrue(tf.merge_all_summaries("bar_key") is None) + + +if __name__ == "__main__": + tf.test.main() |