aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/summary_ops_test.py
blob: 13e5021ccca35fa2b56e9f437f68ec15d689ba5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Tests for summary ops."""
import tensorflow.python.platform

import tensorflow as tf

class SummaryOpsTest(tf.test.TestCase):
  """Tests for the scalar, histogram and merge summary ops."""

  def _AsSummary(self, s):
    """Parses the serialized output of a summary op into a `tf.Summary`.

    Args:
      s: The serialized proto string obtained by running a summary op.

    Returns:
      A `tf.Summary` message populated from `s`.
    """
    summ = tf.Summary()
    summ.ParseFromString(s)
    return summ

  def testScalarSummary(self):
    """A named scalar summary emits one `simple_value` entry per tag."""
    with self.test_session() as sess:
      const = tf.constant([10.0, 20.0])
      summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm")
      value = sess.run(summ)
    # The summary op produces a single (scalar-shaped) serialized proto.
    self.assertEqual([], summ.get_shape())
    # Tags and values are paired positionally: c1 -> 10.0, c2 -> 20.0.
    self.assertProtoEquals("""
      value { tag: "c1" simple_value: 10.0 }
      value { tag: "c2" simple_value: 20.0 }
      """, self._AsSummary(value))

  def testScalarSummaryDefaultName(self):
    """Same as testScalarSummary, but without passing an explicit `name`."""
    with self.test_session() as sess:
      const = tf.constant([10.0, 20.0])
      summ = tf.scalar_summary(["c1", "c2"], const)
      value = sess.run(summ)
    self.assertEqual([], summ.get_shape())
    self.assertProtoEquals("""
      value { tag: "c1" simple_value: 10.0 }
      value { tag: "c2" simple_value: 20.0 }
      """, self._AsSummary(value))

  def testMergeSummary(self):
    """merge_summary concatenates the values of its input summaries in order."""
    with self.test_session() as sess:
      const = tf.constant(10.0)
      summ1 = tf.histogram_summary("h", const, name="histo")
      summ2 = tf.scalar_summary("c", const, name="summ")
      merge = tf.merge_summary([summ1, summ2])
      value = sess.run(merge)
    self.assertEqual([], merge.get_shape())
    # The histogram of the single observation 10.0 has exactly one non-empty
    # bucket (the middle one); the scalar summary value follows it.
    self.assertProtoEquals("""
      value {
        tag: "h"
        histo {
          min: 10.0
          max: 10.0
          num: 1.0
          sum: 10.0
          sum_squares: 100.0
          bucket_limit: 9.93809490288
          bucket_limit: 10.9319043932
          bucket_limit: 1.79769313486e+308
          bucket: 0.0
          bucket: 1.0
          bucket: 0.0
        }
      }
      value { tag: "c" simple_value: 10.0 }
    """, self._AsSummary(value))

  def testMergeAllSummaries(self):
    """merge_all_summaries merges exactly the summaries of one collection."""
    # A fresh graph keeps summaries created by other tests out of the
    # collections inspected below.
    with tf.Graph().as_default():
      const = tf.constant(10.0)
      summ1 = tf.histogram_summary("h", const, name="histo")
      # Registered only under "foo_key", so the default merge must skip it.
      summ2 = tf.scalar_summary("o", const, name="oops",
                                collections=["foo_key"])
      summ3 = tf.scalar_summary("c", const, name="summ")

      # With no argument, only summaries created without an explicit
      # `collections` argument are merged, in creation order.
      merge = tf.merge_all_summaries()
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertEqual(2, len(merge.op.inputs))
      self.assertEqual(summ1, merge.op.inputs[0])
      self.assertEqual(summ3, merge.op.inputs[1])

      # Naming a collection merges only the summaries registered under it.
      merge = tf.merge_all_summaries("foo_key")
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertEqual(1, len(merge.op.inputs))
      self.assertEqual(summ2, merge.op.inputs[0])

      # A collection with no summaries yields None rather than an op.
      self.assertIsNone(tf.merge_all_summaries("bar_key"))


if __name__ == "__main__":
  # Script entry point: delegate to TensorFlow's test runner.
  tf.test.main()