diff options
author | 2017-11-16 10:23:48 -0800 | |
---|---|---|
committer | 2017-11-16 10:28:33 -0800 | |
commit | aa4162ac9f1812a0966d3cd9b5e441e47f035828 (patch) | |
tree | 540cd3e4813da0188c3a290b313c802871b31748 /tensorflow/contrib/summary | |
parent | 9d737356147a730326cfcbdc08b0b876dd0766e6 (diff) |
contrib/summary: refactor summary_test_util
A logdir may contain files other than summary event files, e.g., checkpoints.
So add a method "events_from_file" to load events from a single file.
The existing "events_from_logdir" method now calls the new method.
PiperOrigin-RevId: 175981886
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_test_util.py | 35 |
2 files changed, 31 insertions, 10 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 09169fa6d7..c5ca054f77 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -79,7 +79,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) write() - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 2.0) @@ -92,7 +92,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') @@ -105,7 +105,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0, global_step=global_step) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index 37b546d3ab..794c5b8bab 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -26,16 +26,37 @@ from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile -def events_from_file(logdir): - """Returns all events in the single eventfile in logdir.""" - assert gfile.Exists(logdir) - files = gfile.ListDirectory(logdir) - assert len(files) == 1, "Found more than one file in logdir: %s" % files - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) +def events_from_file(filepath): + """Returns all events in a single event file. + + Args: + filepath: Path to the event file. + + Returns: + A list of all tf.Event protos in the event file. + """ + records = list(tf_record.tf_record_iterator(filepath)) result = [] for r in records: event = event_pb2.Event() event.ParseFromString(r) result.append(event) return result + + +def events_from_logdir(logdir): + """Returns all events in the single eventfile in logdir. + + Args: + logdir: The directory in which the single event file is sought. + + Returns: + A list of all tf.Event protos from the single event file. + + Raises: + AssertionError: If logdir does not contain exactly one file. + """ + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found not exactly one file in logdir: %s" % files + return events_from_file(os.path.join(logdir, files[0])) |