aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/summary_writer_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/summary_writer_test.py')
-rw-r--r--tensorflow/python/training/summary_writer_test.py151
1 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py
new file mode 100644
index 0000000000..2ec416f68f
--- /dev/null
+++ b/tensorflow/python/training/summary_writer_test.py
@@ -0,0 +1,151 @@
+"""Tests for training_coordinator.py."""
+import glob
+import os.path
+import shutil
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class SummaryWriterTestCase(tf.test.TestCase):
+
+ def _TestDir(self, test_name):
+ test_dir = os.path.join(self.get_temp_dir(), test_name)
+ return test_dir
+
+ def _CleanTestDir(self, test_name):
+ test_dir = self._TestDir(test_name)
+ if os.path.exists(test_dir):
+ shutil.rmtree(test_dir)
+ return test_dir
+
+ def _EventsReader(self, test_dir):
+ event_paths = glob.glob(os.path.join(test_dir, "event*"))
+ # If the tests runs multiple time in the same directory we can have
+ # more than one matching event file. We only want to read the last one.
+ self.assertTrue(event_paths)
+ return tf.train.summary_iterator(event_paths[-1])
+
+ def _assertRecent(self, t):
+ self.assertTrue(abs(t - time.time()) < 5)
+
+ def testBasics(self):
+ test_dir = self._CleanTestDir("basics")
+ sw = tf.train.SummaryWriter(test_dir)
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="mee",
+ simple_value=10.0)]),
+ 10)
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="boo",
+ simple_value=20.0)]),
+ 20)
+ with tf.Graph().as_default() as g:
+ tf.constant([0], name="zero")
+ gd = g.as_graph_def()
+ sw.add_graph(gd, global_step=30)
+ sw.close()
+ rr = self._EventsReader(test_dir)
+
+ # The first event should list the file_version.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # The next event should have the value 'mee=10.0'.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(10, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'mee' simple_value: 10.0 }
+ """, ev.summary)
+
+ # The next event should have the value 'boo=20.0'.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(20, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'boo' simple_value: 20.0 }
+ """, ev.summary)
+
+ # The next event should have the graph_def.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(30, ev.step)
+ self.assertProtoEquals(gd, ev.graph_def)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+ def testConstructWithGraph(self):
+ test_dir = self._CleanTestDir("basics_with_graph")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ gd = g.as_graph_def()
+ sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
+ sw.close()
+ rr = self._EventsReader(test_dir)
+
+ # The first event should list the file_version.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # The next event should have the graph.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(0, ev.step)
+ self.assertProtoEquals(gd, ev.graph_def)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+ # Checks that values returned from session Run() calls are added correctly to
+ # summaries. These are numpy types so we need to check they fit in the
+ # protocol buffers correctly.
+ def testSummariesAndStopFromSessionRunCalls(self):
+ test_dir = self._CleanTestDir("global_step")
+ sw = tf.train.SummaryWriter(test_dir)
+ with self.test_session():
+ i = tf.constant(1, dtype=tf.int32, shape=[])
+ l = tf.constant(2, dtype=tf.int64, shape=[])
+ # Test the summary can be passed serialized.
+ summ = tf.Summary(value=[tf.Summary.Value(tag="i", simple_value=1.0)])
+ sw.add_summary(summ.SerializeToString(), i.eval())
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="l",
+ simple_value=2.0)]),
+ l.eval())
+ sw.close()
+
+ rr = self._EventsReader(test_dir)
+
+ # File_version.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # Summary passed serialized.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(1, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'i' simple_value: 1.0 }
+ """, ev.summary)
+
+ # Summary passed as SummaryObject.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(2, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'l' simple_value: 2.0 }
+ """, ev.summary)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+
+if __name__ == "__main__":
+ tf.test.main()