"""Tests for tf.train.SummaryWriter."""
import glob
import os.path
import shutil
import time

import tensorflow.python.platform

import tensorflow as tf


class SummaryWriterTestCase(tf.test.TestCase):
  """Checks the events that tf.train.SummaryWriter writes to disk.

  Each test writes events into its own sub-directory of the test temp
  directory, closes the writer, and then reads the events back with
  tf.train.summary_iterator to verify their content and order.
  """

  def _TestDir(self, test_name):
    """Returns the path of the directory used by the test `test_name`."""
    return os.path.join(self.get_temp_dir(), test_name)

  def _CleanTestDir(self, test_name):
    """Deletes the directory for `test_name` if it exists; returns its path.

    The directory is not re-created here.
    """
    test_dir = self._TestDir(test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _EventsReader(self, test_dir):
    """Returns an iterator over the events of the last event file.

    Args:
      test_dir: Directory in which to look for files matching "event*".

    Returns:
      The result of tf.train.summary_iterator() on the selected file.
    """
    # glob.glob() returns paths in arbitrary order, so sort them to make
    # "the last one" well defined.
    # NOTE(review): this presumes event file names sort chronologically
    # (e.g. because they embed a timestamp) -- confirm against the writer.
    event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
    # If the tests run multiple times in the same directory we can have
    # more than one matching event file.  We only want to read the last one.
    self.assertTrue(event_paths, "No event file found in %s" % test_dir)
    return tf.train.summary_iterator(event_paths[-1])

  def _assertRecent(self, t):
    """Asserts that timestamp `t` is within 5 seconds of the current time."""
    self.assertLess(abs(t - time.time()), 5)

  def testBasics(self):
    """Summaries and a graph added explicitly are written in order."""
    test_dir = self._CleanTestDir("basics")
    sw = tf.train.SummaryWriter(test_dir)
    sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="mee",
                                                      simple_value=10.0)]),
                   10)
    sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="boo",
                                                      simple_value=20.0)]),
                   20)
    with tf.Graph().as_default() as g:
      tf.constant([0], name="zero")
    gd = g.as_graph_def()
    sw.add_graph(gd, global_step=30)
    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # The next event should have the value 'mee=10.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(10, ev.step)
    self.assertProtoEquals("""
      value { tag: 'mee' simple_value: 10.0 }
      """, ev.summary)

    # The next event should have the value 'boo=20.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(20, ev.step)
    self.assertProtoEquals("""
      value { tag: 'boo' simple_value: 20.0 }
      """, ev.summary)

    # The next event should have the graph_def.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(30, ev.step)
    self.assertProtoEquals(gd, ev.graph_def)

    # We should be done.
    self.assertRaises(StopIteration, next, rr)

  def testConstructWithGraph(self):
    """A graph_def passed to the constructor is written at step 0."""
    test_dir = self._CleanTestDir("basics_with_graph")
    with tf.Graph().as_default() as g:
      tf.constant([12], name="douze")
    gd = g.as_graph_def()
    sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # The next event should have the graph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    self.assertProtoEquals(gd, ev.graph_def)

    # We should be done.
    self.assertRaises(StopIteration, next, rr)

  def testSummariesAndStopFromSessionRunCalls(self):
    """Values returned by eval() can be used as global steps.

    Checks that values returned from session Run() calls are added correctly
    to summaries.  These are numpy types so we need to check they fit in the
    protocol buffers correctly.
    """
    test_dir = self._CleanTestDir("global_step")
    sw = tf.train.SummaryWriter(test_dir)
    with self.test_session():
      i = tf.constant(1, dtype=tf.int32, shape=[])
      l = tf.constant(2, dtype=tf.int64, shape=[])
      # Test the summary can be passed serialized.
      summ = tf.Summary(value=[tf.Summary.Value(tag="i", simple_value=1.0)])
      sw.add_summary(summ.SerializeToString(), i.eval())
      sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="l",
                                                        simple_value=2.0)]),
                     l.eval())
      sw.close()

    rr = self._EventsReader(test_dir)

    # File_version.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # Summary passed serialized.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(1, ev.step)
    self.assertProtoEquals("""
      value { tag: 'i' simple_value: 1.0 }
      """, ev.summary)

    # Summary passed as SummaryObject.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(2, ev.step)
    self.assertProtoEquals("""
      value { tag: 'l' simple_value: 2.0 }
      """, ev.summary)

    # We should be done.
    self.assertRaises(StopIteration, next, rr)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()