aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-11-16 10:23:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-16 10:28:33 -0800
commitaa4162ac9f1812a0966d3cd9b5e441e47f035828 (patch)
tree540cd3e4813da0188c3a290b313c802871b31748 /tensorflow/contrib/summary
parent9d737356147a730326cfcbdc08b0b876dd0766e6 (diff)
contrib/summary: refactor summary_test_util
A logdir may contain files other than summary event files, e.g., checkpoints. So add a method "events_from_file" to load events from a single file. The existing "events_from_logdir" method now calls the new method. PiperOrigin-RevId: 175981886
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py6
-rw-r--r--tensorflow/contrib/summary/summary_test_util.py35
2 files changed, 31 insertions, 10 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 09169fa6d7..c5ca054f77 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -79,7 +79,7 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.scalar('scalar', 2.0)
write()
- events = summary_test_util.events_from_file(logdir)
+ events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 2.0)
@@ -92,7 +92,7 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.scalar('scalar', 2.0)
- events = summary_test_util.events_from_file(logdir)
+ events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
@@ -105,7 +105,7 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.scalar('scalar', 2.0, global_step=global_step)
- events = summary_test_util.events_from_file(logdir)
+ events = summary_test_util.events_from_logdir(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py
index 37b546d3ab..794c5b8bab 100644
--- a/tensorflow/contrib/summary/summary_test_util.py
+++ b/tensorflow/contrib/summary/summary_test_util.py
@@ -26,16 +26,37 @@ from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
-def events_from_file(logdir):
- """Returns all events in the single eventfile in logdir."""
- assert gfile.Exists(logdir)
- files = gfile.ListDirectory(logdir)
- assert len(files) == 1, "Found more than one file in logdir: %s" % files
- records = list(
- tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+def events_from_file(filepath):
+ """Returns all events in a single event file.
+
+ Args:
+ filepath: Path to the event file.
+
+ Returns:
+ A list of all tf.Event protos in the event file.
+ """
+ records = list(tf_record.tf_record_iterator(filepath))
result = []
for r in records:
event = event_pb2.Event()
event.ParseFromString(r)
result.append(event)
return result
+
+
+def events_from_logdir(logdir):
+ """Returns all events in the single eventfile in logdir.
+
+ Args:
+ logdir: The directory in which the single event file is sought.
+
+ Returns:
+ A list of all tf.Event protos from the single event file.
+
+ Raises:
+ AssertionError: If logdir does not contain exactly one file.
+ """
+ assert gfile.Exists(logdir)
+ files = gfile.ListDirectory(logdir)
+ assert len(files) == 1, "Found not exactly one file in logdir: %s" % files
+ return events_from_file(os.path.join(logdir, files[0]))