aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary/summary_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/summary/summary_ops_test.py')
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py29
1 files changed, 7 insertions, 22 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 405a92a726..de7ae6ec27 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,16 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import tempfile
from tensorflow.contrib.summary import summary_ops
-from tensorflow.core.util import event_pb2
+from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
-from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -71,16 +69,9 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.scalar('scalar', 2.0)
write()
-
- self.assertTrue(gfile.Exists(logdir))
- files = gfile.ListDirectory(logdir)
- self.assertEqual(len(files), 1)
- records = list(
- tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
- self.assertEqual(len(records), 2)
- event = event_pb2.Event()
- event.ParseFromString(records[1])
- self.assertEqual(event.summary.value[0].simple_value, 2.0)
+ events = summary_test_util.events_from_file(logdir)
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[1].summary.value[0].simple_value, 2.0)
def testSummaryName(self):
training_util.get_or_create_global_step()
@@ -91,15 +82,9 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.scalar('scalar', 2.0)
- self.assertTrue(gfile.Exists(logdir))
- files = gfile.ListDirectory(logdir)
- self.assertEqual(len(files), 1)
- records = list(
- tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
- self.assertEqual(len(records), 2)
- event = event_pb2.Event()
- event.ParseFromString(records[1])
- self.assertEqual(event.summary.value[0].tag, 'scalar')
+ events = summary_test_util.events_from_file(logdir)
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[1].summary.value[0].tag, 'scalar')
if __name__ == '__main__':