diff options
Diffstat (limited to 'tensorflow/contrib/summary/summary_ops_test.py')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 29 |
1 files changed, 7 insertions, 22 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 405a92a726..de7ae6ec27 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import tempfile from tensorflow.contrib.summary import summary_ops -from tensorflow.core.util import event_pb2 +from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import errors from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -71,16 +69,9 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) write() - - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].simple_value, 2.0) + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 2.0) def testSummaryName(self): training_util.get_or_create_global_step() @@ -91,15 +82,9 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].tag, 'scalar') + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') if __name__ == '__main__': |