author    Nick Felt <nickfelt@google.com>  2018-04-10 23:44:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-10 23:46:53 -0700
commit    6accb84d8437cb915e23d83673c233f5084aad68 (patch)
tree      44a28a02a04b2cbcc6f3f340f8ec60e48abb8bdb
parent    231146433a45ca8135e132ee0b48469798ca0b1f (diff)
Create FileWriter <-> tf.contrib.summary compatibility layer
This provides an implementation of FileWriter, activated by passing in a
`session` parameter to the constructor, that is backed by session.run'ing
graph ops that manipulate a tf.contrib.summary.create_file_writer()
instance. Because tf.contrib.summary.SummaryWriters are backed by shared
resources in the graph, this makes it possible to have a FileWriter and a
tf.contrib.summary.SummaryWriter that both write to the same events file.

This change includes some related smaller changes:

- Factors out training_utils.py into a separate target to avoid a cyclic dep
- Moves contrib/summary/summary_ops.py to python/ops/summary_ops_v2.py
- Adds SummaryWriter.init(), .flush(), and .close() op-returning methods
- Changes create_file_writer() `name` arg to default to logdir prefixed by
  `logdir:` so shared resources are scoped by logdir by default
- Fixes a bug with tf.contrib.summary.flush() `writer` arg
- Makes create_file_writer()'s max_queue arg behave as documented
- Adds more testing for existing tf.contrib.summary API

PiperOrigin-RevId: 192408079
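For illustration, a minimal graph-mode sketch of the compatibility layer
described above (the logdir path is hypothetical; the API names are those
introduced or documented in this change):

    import tensorflow as tf

    logdir = '/tmp/logdir'  # hypothetical path
    with tf.Session() as sess:
      # Passing `session` activates the compatibility layer: this FileWriter
      # is backed by graph ops on a tf.contrib.summary writer resource, which
      # the constructor initializes via sess.run().
      legacy_writer = tf.summary.FileWriter(logdir, session=sess)

      # A v2 writer for the same logdir resolves to the same shared resource
      # (the default shared name is 'logdir:' + logdir), so both writers
      # append to the same events file.
      v2_writer = tf.contrib.summary.create_file_writer(logdir)
      with v2_writer.as_default(), \
          tf.contrib.summary.always_record_summaries():
        tf.contrib.summary.scalar('loss', 0.5, step=1)

      sess.run(tf.contrib.summary.all_summary_ops())
      legacy_writer.add_summary(
          tf.Summary(value=[tf.Summary.Value(tag='acc', simple_value=0.9)]),
          global_step=1)
      legacy_writer.flush()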
-rw-r--r--  tensorflow/contrib/eager/python/BUILD                                 6
-rw-r--r--  tensorflow/contrib/eager/python/evaluator.py                          2
-rw-r--r--  tensorflow/contrib/eager/python/metrics_impl.py                       2
-rw-r--r--  tensorflow/contrib/eager/python/metrics_test.py                       2
-rw-r--r--  tensorflow/contrib/summary/BUILD                                     33
-rw-r--r--  tensorflow/contrib/summary/summary.py                                40
-rw-r--r--  tensorflow/contrib/summary/summary_ops_graph_test.py                197
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py                      113
-rw-r--r--  tensorflow/contrib/summary/summary_test_internal.py                  60
-rw-r--r--  tensorflow/contrib/summary/summary_test_util.py                       2
-rw-r--r--  tensorflow/contrib/tensorboard/db/summary_file_writer.cc              2
-rw-r--r--  tensorflow/contrib/tpu/BUILD                                          2
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py                    2
-rw-r--r--  tensorflow/python/BUILD                                              54
-rw-r--r--  tensorflow/python/ops/summary_ops_v2.py (renamed from tensorflow/contrib/summary/summary_ops.py)  68
-rw-r--r--  tensorflow/python/summary/writer/event_file_writer_v2.py            140
-rw-r--r--  tensorflow/python/summary/writer/writer.py                            40
-rw-r--r--  tensorflow/python/summary/writer/writer_test.py                      233
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt     2
19 files changed, 797 insertions, 203 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 4e088503bf..d97048405d 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -120,13 +120,13 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/contrib/eager/python:checkpointable_utils",
- "//tensorflow/contrib/summary:summary_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
@@ -140,11 +140,11 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":metrics",
- "//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/summary:summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -161,10 +161,10 @@ py_library(
deps = [
":datasets",
":metrics",
- "//tensorflow/contrib/summary:summary_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
"@six_archive//:six",
diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py
index 37c8f0d47a..7949a3f6da 100644
--- a/tensorflow/contrib/eager/python/evaluator.py
+++ b/tensorflow/contrib/eager/python/evaluator.py
@@ -22,12 +22,12 @@ import six
from tensorflow.contrib.eager.python import datasets
from tensorflow.contrib.eager.python import metrics
-from tensorflow.contrib.summary import summary_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
class Evaluator(object):
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 2f2347736a..907f9204c2 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import re
-from tensorflow.contrib.summary import summary_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
@@ -29,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 15ac889191..28f5f286eb 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -23,7 +23,6 @@ import tempfile
from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.eager.python import metrics
-from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -31,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index fda1367b15..f88b03ec4c 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -15,7 +15,6 @@ py_test(
srcs = ["summary_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":summary_ops",
":summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
@@ -23,6 +22,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
@@ -35,7 +35,6 @@ py_test(
srcs = ["summary_ops_graph_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":summary_ops",
":summary_test_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -44,31 +43,9 @@ py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "summary_ops",
- srcs = ["summary_ops.py"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:layers_base",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:summary_op_util",
- "//tensorflow/python:summary_ops_gen",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python:variables",
"@six_archive//:six",
],
)
@@ -79,7 +56,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
- ":summary_ops",
+ "//tensorflow/python:summary_ops_v2",
],
)
@@ -92,8 +69,10 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
+ "//tensorflow/python:summary_ops_v2",
"@org_sqlite//:python",
],
)
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index 2d6d7ea6a3..99ced53e11 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -61,23 +61,23 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
-from tensorflow.contrib.summary.summary_ops import all_summary_ops
-from tensorflow.contrib.summary.summary_ops import always_record_summaries
-from tensorflow.contrib.summary.summary_ops import audio
-from tensorflow.contrib.summary.summary_ops import create_db_writer
-from tensorflow.contrib.summary.summary_ops import create_file_writer
-from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
-from tensorflow.contrib.summary.summary_ops import eval_dir
-from tensorflow.contrib.summary.summary_ops import flush
-from tensorflow.contrib.summary.summary_ops import generic
-from tensorflow.contrib.summary.summary_ops import graph
-from tensorflow.contrib.summary.summary_ops import histogram
-from tensorflow.contrib.summary.summary_ops import image
-from tensorflow.contrib.summary.summary_ops import import_event
-from tensorflow.contrib.summary.summary_ops import initialize
-from tensorflow.contrib.summary.summary_ops import never_record_summaries
-from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
-from tensorflow.contrib.summary.summary_ops import scalar
-from tensorflow.contrib.summary.summary_ops import should_record_summaries
-from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op
-from tensorflow.contrib.summary.summary_ops import SummaryWriter
+from tensorflow.python.ops.summary_ops_v2 import all_summary_ops
+from tensorflow.python.ops.summary_ops_v2 import always_record_summaries
+from tensorflow.python.ops.summary_ops_v2 import audio
+from tensorflow.python.ops.summary_ops_v2 import create_db_writer
+from tensorflow.python.ops.summary_ops_v2 import create_file_writer
+from tensorflow.python.ops.summary_ops_v2 import create_summary_file_writer
+from tensorflow.python.ops.summary_ops_v2 import eval_dir
+from tensorflow.python.ops.summary_ops_v2 import flush
+from tensorflow.python.ops.summary_ops_v2 import generic
+from tensorflow.python.ops.summary_ops_v2 import graph
+from tensorflow.python.ops.summary_ops_v2 import histogram
+from tensorflow.python.ops.summary_ops_v2 import image
+from tensorflow.python.ops.summary_ops_v2 import import_event
+from tensorflow.python.ops.summary_ops_v2 import initialize
+from tensorflow.python.ops.summary_ops_v2 import never_record_summaries
+from tensorflow.python.ops.summary_ops_v2 import record_summaries_every_n_global_steps
+from tensorflow.python.ops.summary_ops_v2 import scalar
+from tensorflow.python.ops.summary_ops_v2 import should_record_summaries
+from tensorflow.python.ops.summary_ops_v2 import summary_writer_initializer_op
+from tensorflow.python.ops.summary_ops_v2 import SummaryWriter
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index 3aba04540e..ae8336daaf 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -16,27 +16,220 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import tempfile
+import time
import six
-from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
get_all = summary_test_util.get_all
-class DbTest(summary_test_util.SummaryDbTest):
+class GraphFileTest(test_util.TensorFlowTestCase):
+
+ def testSummaryOps(self):
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(logdir, max_queue=0)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ summary_ops.generic('tensor', 1, step=1)
+ summary_ops.scalar('scalar', 2.0, step=1)
+ summary_ops.histogram('histogram', [1.0], step=1)
+ summary_ops.image('image', [[[[1.0]]]], step=1)
+ summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ sess.run(summary_ops.all_summary_ops())
+ # The working condition of the ops is tested in the C++ test so we just
+ # test here that we're calling them correctly.
+ self.assertTrue(gfile.Exists(logdir))
+
+ def testSummaryName(self):
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(logdir, max_queue=0)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ summary_ops.scalar('scalar', 2.0, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ sess.run(summary_ops.all_summary_ops())
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(2, len(events))
+ self.assertEqual('scalar', events[1].summary.value[0].tag)
+
+ def testSummaryNameScope(self):
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(logdir, max_queue=0)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ with ops.name_scope('scope'):
+ summary_ops.scalar('scalar', 2.0, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ sess.run(summary_ops.all_summary_ops())
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(2, len(events))
+ self.assertEqual('scope/scalar', events[1].summary.value[0].tag)
+
+ def testSummaryGlobalStep(self):
+ training_util.get_or_create_global_step()
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(logdir, max_queue=0)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ summary_ops.scalar('scalar', 2.0)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(summary_ops.summary_writer_initializer_op())
+ step, _ = sess.run(
+ [training_util.get_global_step(), summary_ops.all_summary_ops()])
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(2, len(events))
+ self.assertEqual(step, events[1].step)
+
+ def testMaxQueue(self):
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=1, flush_millis=999999)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ summary_ops.scalar('scalar', 2.0, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ # Note: First tf.Event is always file_version.
+ self.assertEqual(1, get_total())
+ sess.run(summary_ops.all_summary_ops())
+ self.assertEqual(1, get_total())
+ # Should flush after second summary since max_queue = 1
+ sess.run(summary_ops.all_summary_ops())
+ self.assertEqual(3, get_total())
+
+ def testFlushFunction(self):
+ logdir = self.get_temp_dir()
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=999999, flush_millis=999999)
+ with writer.as_default(), summary_ops.always_record_summaries():
+ summary_ops.scalar('scalar', 2.0, step=1)
+ flush_op = summary_ops.flush()
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ # Note: First tf.Event is always file_version.
+ self.assertEqual(1, get_total())
+ sess.run(summary_ops.all_summary_ops())
+ self.assertEqual(1, get_total())
+ sess.run(flush_op)
+ self.assertEqual(2, get_total())
+ # Test "writer" parameter
+ sess.run(summary_ops.all_summary_ops())
+ sess.run(summary_ops.flush(writer=writer))
+ self.assertEqual(3, get_total())
+ sess.run(summary_ops.all_summary_ops())
+ sess.run(summary_ops.flush(writer=writer._resource)) # pylint:disable=protected-access
+ self.assertEqual(4, get_total())
+
+ def testSharedName(self):
+ logdir = self.get_temp_dir()
+ with summary_ops.always_record_summaries():
+ # Create with default shared name (should match logdir)
+ writer1 = summary_ops.create_file_writer(logdir)
+ with writer1.as_default():
+ summary_ops.scalar('one', 1.0, step=1)
+ # Create with explicit logdir shared name (should be same resource/file)
+ shared_name = 'logdir:' + logdir
+ writer2 = summary_ops.create_file_writer(logdir, name=shared_name)
+ with writer2.as_default():
+ summary_ops.scalar('two', 2.0, step=2)
+ # Create with different shared name (should be separate resource/file)
+ writer3 = summary_ops.create_file_writer(logdir, name='other')
+ with writer3.as_default():
+ summary_ops.scalar('three', 3.0, step=3)
+
+ with self.test_session() as sess:
+ # Run init ops across writers sequentially to avoid race condition.
+ # TODO(nickfelt): fix race condition in resource manager lookup or create
+ sess.run(writer1.init())
+ sess.run(writer2.init())
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ sess.run(writer3.init())
+ sess.run(summary_ops.all_summary_ops())
+ sess.run([writer1.flush(), writer2.flush(), writer3.flush()])
+
+ event_files = iter(sorted(gfile.Glob(os.path.join(logdir, '*tfevents*'))))
+
+ # First file has tags "one" and "two"
+ events = summary_test_util.events_from_file(next(event_files))
+ self.assertEqual('brain.Event:2', events[0].file_version)
+ tags = [e.summary.value[0].tag for e in events[1:]]
+ self.assertItemsEqual(['one', 'two'], tags)
+
+ # Second file has tag "three"
+ events = summary_test_util.events_from_file(next(event_files))
+ self.assertEqual('brain.Event:2', events[0].file_version)
+ tags = [e.summary.value[0].tag for e in events[1:]]
+ self.assertItemsEqual(['three'], tags)
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_files))
+
+ def testWriterInitAndClose(self):
+ logdir = self.get_temp_dir()
+ with summary_ops.always_record_summaries():
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=100, flush_millis=1000000)
+ with writer.as_default():
+ summary_ops.scalar('one', 1.0, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ self.assertEqual(1, get_total()) # file_version Event
+ # Running init() again while writer is open has no effect
+ sess.run(writer.init())
+ self.assertEqual(1, get_total())
+ sess.run(summary_ops.all_summary_ops())
+ self.assertEqual(1, get_total())
+ # Running close() should do an implicit flush
+ sess.run(writer.close())
+ self.assertEqual(2, get_total())
+ # Running init() on a closed writer should start a new file
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ sess.run(writer.init())
+ sess.run(summary_ops.all_summary_ops())
+ sess.run(writer.close())
+ files = sorted(gfile.Glob(os.path.join(logdir, '*tfevents*')))
+ self.assertEqual(2, len(files))
+ self.assertEqual(2, len(summary_test_util.events_from_file(files[1])))
+
+ def testWriterFlush(self):
+ logdir = self.get_temp_dir()
+ with summary_ops.always_record_summaries():
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=100, flush_millis=1000000)
+ with writer.as_default():
+ summary_ops.scalar('one', 1.0, step=1)
+ with self.test_session() as sess:
+ sess.run(summary_ops.summary_writer_initializer_op())
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ self.assertEqual(1, get_total()) # file_version Event
+ sess.run(summary_ops.all_summary_ops())
+ self.assertEqual(1, get_total())
+ sess.run(writer.flush())
+ self.assertEqual(2, get_total())
+
+
+class GraphDbTest(summary_test_util.SummaryDbTest):
def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
with self.assertRaises(TypeError):
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index c756f8b270..f1ef218e74 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -16,12 +16,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import tempfile
+import time
import numpy as np
import six
-from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -57,7 +59,7 @@ _NUMPY_NUMERIC_TYPES = {
}
-class TargetTest(test_util.TensorFlowTestCase):
+class EagerFileTest(test_util.TensorFlowTestCase):
def testShouldRecordSummary(self):
self.assertFalse(summary_ops.should_record_summaries())
@@ -138,21 +140,22 @@ class TargetTest(test_util.TensorFlowTestCase):
def testMaxQueue(self):
logs = tempfile.mkdtemp()
with summary_ops.create_file_writer(
- logs, max_queue=2, flush_millis=999999,
+ logs, max_queue=1, flush_millis=999999,
name='lol').as_default(), summary_ops.always_record_summaries():
get_total = lambda: len(summary_test_util.events_from_logdir(logs))
# Note: First tf.Event is always file_version.
self.assertEqual(1, get_total())
summary_ops.scalar('scalar', 2.0, step=1)
self.assertEqual(1, get_total())
+ # Should flush after second summary since max_queue = 1
summary_ops.scalar('scalar', 2.0, step=2)
self.assertEqual(3, get_total())
- def testFlush(self):
+ def testFlushFunction(self):
logs = tempfile.mkdtemp()
- with summary_ops.create_file_writer(
- logs, max_queue=999999, flush_millis=999999,
- name='lol').as_default(), summary_ops.always_record_summaries():
+ writer = summary_ops.create_file_writer(
+ logs, max_queue=999999, flush_millis=999999, name='lol')
+ with writer.as_default(), summary_ops.always_record_summaries():
get_total = lambda: len(summary_test_util.events_from_logdir(logs))
# Note: First tf.Event is always file_version.
self.assertEqual(1, get_total())
@@ -161,9 +164,103 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(1, get_total())
summary_ops.flush()
self.assertEqual(3, get_total())
+ # Test "writer" parameter
+ summary_ops.scalar('scalar', 2.0, step=3)
+ summary_ops.flush(writer=writer)
+ self.assertEqual(4, get_total())
+ summary_ops.scalar('scalar', 2.0, step=4)
+ summary_ops.flush(writer=writer._resource) # pylint:disable=protected-access
+ self.assertEqual(5, get_total())
+
+ def testSharedName(self):
+ logdir = self.get_temp_dir()
+ with summary_ops.always_record_summaries():
+ # Create with default shared name (should match logdir)
+ writer1 = summary_ops.create_file_writer(logdir)
+ with writer1.as_default():
+ summary_ops.scalar('one', 1.0, step=1)
+ summary_ops.flush()
+ # Create with explicit logdir shared name (should be same resource/file)
+ shared_name = 'logdir:' + logdir
+ writer2 = summary_ops.create_file_writer(logdir, name=shared_name)
+ with writer2.as_default():
+ summary_ops.scalar('two', 2.0, step=2)
+ summary_ops.flush()
+ # Create with different shared name (should be separate resource/file)
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ writer3 = summary_ops.create_file_writer(logdir, name='other')
+ with writer3.as_default():
+ summary_ops.scalar('three', 3.0, step=3)
+ summary_ops.flush()
+
+ event_files = iter(sorted(gfile.Glob(os.path.join(logdir, '*tfevents*'))))
+
+ # First file has tags "one" and "two"
+ events = iter(summary_test_util.events_from_file(next(event_files)))
+ self.assertEqual('brain.Event:2', next(events).file_version)
+ self.assertEqual('one', next(events).summary.value[0].tag)
+ self.assertEqual('two', next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Second file has tag "three"
+ events = iter(summary_test_util.events_from_file(next(event_files)))
+ self.assertEqual('brain.Event:2', next(events).file_version)
+ self.assertEqual('three', next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_files))
+
+ def testWriterInitAndClose(self):
+ logdir = self.get_temp_dir()
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ with summary_ops.always_record_summaries():
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=100, flush_millis=1000000)
+ self.assertEqual(1, get_total()) # file_version Event
+ # Calling init() again while writer is open has no effect
+ writer.init()
+ self.assertEqual(1, get_total())
+ try:
+ # Not using .as_default() to avoid implicit flush when exiting
+ writer.set_as_default()
+ summary_ops.scalar('one', 1.0, step=1)
+ self.assertEqual(1, get_total())
+ # Calling .close() should do an implicit flush
+ writer.close()
+ self.assertEqual(2, get_total())
+ # Calling init() on a closed writer should start a new file
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ writer.init()
+ files = sorted(gfile.Glob(os.path.join(logdir, '*tfevents*')))
+ self.assertEqual(2, len(files))
+ get_total = lambda: len(summary_test_util.events_from_file(files[1]))
+ self.assertEqual(1, get_total()) # file_version Event
+ summary_ops.scalar('two', 2.0, step=2)
+ writer.close()
+ self.assertEqual(2, get_total())
+ finally:
+ # Clean up by resetting default writer
+ summary_ops.create_file_writer(None).set_as_default()
+
+ def testWriterFlush(self):
+ logdir = self.get_temp_dir()
+ get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
+ with summary_ops.always_record_summaries():
+ writer = summary_ops.create_file_writer(
+ logdir, max_queue=100, flush_millis=1000000)
+ self.assertEqual(1, get_total()) # file_version Event
+ with writer.as_default():
+ summary_ops.scalar('one', 1.0, step=1)
+ self.assertEqual(1, get_total())
+ writer.flush()
+ self.assertEqual(2, get_total())
+ summary_ops.scalar('two', 2.0, step=2)
+ # Exiting the "as_default()" should do an implicit flush of the "two" tag
+ self.assertEqual(3, get_total())
-class DbTest(summary_test_util.SummaryDbTest):
+class EagerDbTest(summary_test_util.SummaryDbTest):
def testIntegerSummaries(self):
step = training_util.create_global_step()
diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py
deleted file mode 100644
index d0d3384735..0000000000
--- a/tensorflow/contrib/summary/summary_test_internal.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Internal helpers for tests in this directory."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import functools
-import os
-
-import sqlite3
-
-from tensorflow.contrib.summary import summary_ops
-from tensorflow.python.framework import test_util
-
-
-class SummaryDbTest(test_util.TensorFlowTestCase):
- """Helper for summary database testing."""
-
- def setUp(self):
- super(SummaryDbTest, self).setUp()
- self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
- if os.path.exists(self.db_path):
- os.unlink(self.db_path)
- self.db = sqlite3.connect(self.db_path)
- self.create_db_writer = functools.partial(
- summary_ops.create_db_writer,
- db_uri=self.db_path,
- experiment_name='experiment',
- run_name='run',
- user_name='user')
-
- def tearDown(self):
- self.db.close()
- super(SummaryDbTest, self).tearDown()
-
-
-def get_one(db, q, *p):
- return db.execute(q, p).fetchone()[0]
-
-
-def get_all(db, q, *p):
- return unroll(db.execute(q, p).fetchall())
-
-
-def unroll(list_of_tuples):
- return sum(list_of_tuples, ())
diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py
index 8506c4be9c..b4ae43302c 100644
--- a/tensorflow/contrib/summary/summary_test_util.py
+++ b/tensorflow/contrib/summary/summary_test_util.py
@@ -24,10 +24,10 @@ import os
import sqlite3
-from tensorflow.contrib.summary import summary_ops
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
+from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.platform import gfile
diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc
index 85b3e7231b..3f24f58f03 100644
--- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc
@@ -132,7 +132,7 @@ class SummaryFileWriter : public SummaryWriterInterface {
Status WriteEvent(std::unique_ptr<Event> event) override {
mutex_lock ml(mu_);
queue_.emplace_back(std::move(event));
- if (queue_.size() >= max_queue_ ||
+ if (queue_.size() > max_queue_ ||
env_->NowMicros() - last_flush_ > 1000 * flush_millis_) {
return InternalFlush();
}
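This one-character change is the `max_queue` fix noted in the commit message:
with `>=`, the writer flushed as soon as the queue merely reached `max_queue`,
rather than once the queue got bigger than `max_queue` as documented. A sketch
of the documented behavior in eager mode, mirroring testMaxQueue above
(hypothetical logdir path; assumes TF ~1.7+ eager execution):

    import tensorflow as tf
    tf.enable_eager_execution()

    writer = tf.contrib.summary.create_file_writer(
        '/tmp/logdir', max_queue=1, flush_millis=999999)
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar('scalar', 2.0, step=1)  # queued, size 1
      tf.contrib.summary.scalar('scalar', 2.0, step=2)  # size 2 > 1: flush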
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 2f4a76720d..3e489d38b6 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -46,7 +46,6 @@ py_library(
deps = [
":tpu_lib",
":tpu_py",
- "//tensorflow/contrib/summary:summary_ops",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -57,6 +56,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
+ "//tensorflow/python:summary_ops_v2",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 1332108d04..7fab19afee 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -30,7 +30,6 @@ import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.contrib.summary import summary_ops as contrib_summary
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
@@ -57,6 +56,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7b548d2c70..9707b370c0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2550,6 +2550,30 @@ py_library(
)
py_library(
+ name = "summary_ops_v2",
+ srcs = ["ops/summary_ops_v2.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":array_ops",
+ ":constant_op",
+ ":control_flow_ops",
+ ":dtypes",
+ ":framework_ops",
+ ":math_ops",
+ ":resource_variable_ops",
+ ":smart_cond",
+ ":summary_op_util",
+ ":summary_ops_gen",
+ ":training_util",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "template",
srcs = ["ops/template.py"],
srcs_version = "PY2AND3",
@@ -2911,7 +2935,10 @@ py_library(
name = "training",
srcs = glob(
["training/**/*.py"],
- exclude = ["**/*test*"],
+ exclude = [
+ "**/*test*",
+ "training/training_util.py", # See :training_util
+ ],
),
srcs_version = "PY2AND3",
deps = [
@@ -2945,6 +2972,7 @@ py_library(
":string_ops",
":summary",
":training_ops_gen",
+ ":training_util",
":util",
":variable_scope",
":variables",
@@ -4194,6 +4222,25 @@ py_test(
],
)
+py_library(
+ name = "training_util",
+ srcs = ["training/training_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dtypes",
+ ":framework",
+ ":framework_ops",
+ ":init_ops",
+ ":platform",
+ ":resource_variable_ops",
+ ":state_ops",
+ ":util",
+ ":variable_scope",
+ ":variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
py_test(
name = "training_util_test",
size = "small",
@@ -4204,6 +4251,7 @@ py_test(
":framework",
":platform",
":training",
+ ":training_util",
":variables",
],
)
@@ -4248,6 +4296,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":client",
":constant_op",
":errors",
":framework",
@@ -4260,6 +4309,7 @@ py_library(
":summary_op_util",
":summary_ops",
":summary_ops_gen",
+ ":summary_ops_v2",
":util",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -4286,7 +4336,7 @@ py_tests(
":platform",
":platform_test",
":summary",
- ":training",
+ ":summary_ops_v2",
"//tensorflow/core:protos_all_py",
],
)
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/python/ops/summary_ops_v2.py
index bc763fe655..12f361c513 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -31,7 +31,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.layers import utils
+from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_summary_ops
@@ -108,8 +108,10 @@ class SummaryWriter(object):
- @{tf.contrib.summary.create_db_writer}
"""
- def __init__(self, resource):
+ def __init__(self, resource, init_op_fn):
self._resource = resource
+ # TODO(nickfelt): cache constructed ops in graph mode
+ self._init_op_fn = init_op_fn
if context.executing_eagerly() and self._resource is not None:
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="cpu:0")
@@ -129,10 +131,32 @@ class SummaryWriter(object):
yield self
# Flushes the summary writer in eager mode or in graph functions, but not
# in legacy graph mode (you're on your own there).
- with ops.device("cpu:0"):
- gen_summary_ops.flush_summary_writer(self._resource)
+ self.flush()
context.context().summary_writer_resource = old
+ def init(self):
+ """Operation to initialize the summary writer resource."""
+ if self._resource is not None:
+ return self._init_op_fn()
+
+ def _flush(self):
+ return _flush_fn(writer=self)
+
+ def flush(self):
+ """Operation to force the summary writer to flush any buffered data."""
+ if self._resource is not None:
+ return self._flush()
+
+ def _close(self):
+ with ops.control_dependencies([self.flush()]):
+ with ops.device("cpu:0"):
+ return gen_summary_ops.close_summary_writer(self._resource)
+
+ def close(self):
+ """Operation to flush and close the summary writer resource."""
+ if self._resource is not None:
+ return self._close()
+
def initialize(
graph=None, # pylint: disable=redefined-outer-name
@@ -178,7 +202,7 @@ def create_file_writer(logdir,
flush_millis=None,
filename_suffix=None,
name=None):
- """Creates a summary file writer in the current context.
+ """Creates a summary file writer in the current context under the given name.
Args:
logdir: a string, or None. If a string, creates a summary file writer
@@ -186,18 +210,20 @@ def create_file_writer(logdir,
a mock object which acts like a summary writer but does nothing,
useful to use as a context manager.
max_queue: the largest number of summaries to keep in a queue; will
- flush once the queue gets bigger than this.
- flush_millis: the largest interval between flushes.
- filename_suffix: optional suffix for the event file name.
+ flush once the queue gets bigger than this. Defaults to 10.
+ flush_millis: the largest interval between flushes. Defaults to 120,000.
+ filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
name: Shared name for this SummaryWriter resource stored to default
- Graph.
+ Graph. Defaults to the provided logdir prefixed with `logdir:`. Note: if a
+ summary writer resource with this shared name already exists, the returned
+ SummaryWriter wraps that resource and the other arguments have no effect.
Returns:
Either a summary writer or an empty object which can be used as a
summary writer.
"""
if logdir is None:
- return SummaryWriter(None)
+ return SummaryWriter(None, None)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
@@ -205,6 +231,8 @@ def create_file_writer(logdir,
flush_millis = constant_op.constant(2 * 60 * 1000)
if filename_suffix is None:
filename_suffix = constant_op.constant(".v2")
+ if name is None:
+ name = "logdir:" + logdir
return _make_summary_writer(
name,
gen_summary_ops.create_summary_file_writer,
@@ -267,13 +295,12 @@ def create_db_writer(db_uri,
def _make_summary_writer(name, factory, **kwargs):
resource = gen_summary_ops.summary_writer(shared_name=name)
+ init_op_fn = lambda: factory(resource, **kwargs)
# TODO(apassos): Consider doing this instead.
- # node = factory(resource, **kwargs)
# if not context.executing_eagerly():
- # ops.get_default_session().run(node)
- ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
- factory(resource, **kwargs))
- return SummaryWriter(resource)
+ # ops.get_default_session().run(init_op)
+ ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn())
+ return SummaryWriter(resource, init_op_fn)
def _cleanse_string(name, pattern, value):
@@ -341,7 +368,7 @@ def summary_writer_function(name, tensor, function, family=None):
if context.context().summary_writer_resource is None:
return control_flow_ops.no_op()
with ops.device("cpu:0"):
- op = utils.smart_cond(
+ op = smart_cond.smart_cond(
should_record_summaries(), record, _nothing, name="")
ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
return op
@@ -538,7 +565,14 @@ def flush(writer=None, name=None):
writer = context.context().summary_writer_resource
if writer is None:
return control_flow_ops.no_op()
- return gen_summary_ops.flush_summary_writer(writer, name=name)
+ else:
+ if isinstance(writer, SummaryWriter):
+ writer = writer._resource # pylint: disable=protected-access
+ with ops.device("cpu:0"):
+ return gen_summary_ops.flush_summary_writer(writer, name=name)
+
+
+_flush_fn = flush # for within SummaryWriter.flush()
def eval_dir(model_dir, name=None):
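A graph-mode sketch of the new op-returning init(), flush(), and close()
methods added above (hypothetical logdir path):

    import tensorflow as tf

    writer = tf.contrib.summary.create_file_writer('/tmp/logdir')
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar('loss', 0.25, step=1)

    with tf.Session() as sess:
      sess.run(writer.init())   # opens the events file (no-op if already open)
      sess.run(tf.contrib.summary.all_summary_ops())
      sess.run(writer.flush())  # forces buffered events to disk
      sess.run(writer.close())  # implicit flush, then closes the file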
diff --git a/tensorflow/python/summary/writer/event_file_writer_v2.py b/tensorflow/python/summary/writer/event_file_writer_v2.py
new file mode 100644
index 0000000000..5c66c0f7a8
--- /dev/null
+++ b/tensorflow/python/summary/writer/event_file_writer_v2.py
@@ -0,0 +1,140 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Writes events to disk in a logdir."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.platform import gfile
+
+
+class EventFileWriterV2(object):
+ """Writes `Event` protocol buffers to an event file via the graph.
+
+ The `EventFileWriterV2` class is backed by the summary file writer in the v2
+ summary API (currently in tf.contrib.summary), so it uses a shared summary
+ writer resource and graph ops to write events.
+
+ As with the original EventFileWriter, this class will asynchronously write
+ Event protocol buffers to the backing file. The Event file is encoded using
+ the tfrecord format, which is similar to RecordIO.
+ """
+
+ def __init__(self, session, logdir, max_queue=10, flush_secs=120,
+ filename_suffix=''):
+ """Creates an `EventFileWriterV2` and an event file to write to.
+
+ On construction, this calls `tf.contrib.summary.create_file_writer` within
+ the graph from `session.graph` to look up a shared summary writer resource
+ for `logdir` if one exists, and create one if not. Creating the summary
+ writer resource in turn creates a new event file in `logdir` to be filled
+ with `Event` protocol buffers passed to `add_event`. Graph ops to control
+ this writer resource are added to `session.graph` during this init call;
+ stateful methods on this class will call `session.run()` on these ops.
+
+ Note that because the underlying resource is shared, it is possible that
+ other parts of the code using the same session may interact independently
+ with the resource, e.g. by flushing or even closing it. It is the caller's
+ responsibility to avoid any undesirable sharing in this regard.
+
+ The remaining arguments to the constructor (`flush_secs`, `max_queue`, and
+ `filename_suffix`) control the construction of the shared writer resource
+ if one is created. If an existing resource is reused, these arguments have
+ no effect. See `tf.contrib.summary.create_file_writer` for details.
+
+ Args:
+ session: A `tf.Session`. Session that will hold shared writer resource.
+ The writer ops will be added to session.graph during this init call.
+ logdir: A string. Directory where event file will be written.
+ max_queue: Integer. Size of the queue for pending events and summaries.
+ flush_secs: Number. How often, in seconds, to flush the
+ pending events and summaries to disk.
+ filename_suffix: A string. Every event file's name is suffixed with
+ `filename_suffix`.
+ """
+ self._session = session
+ self._logdir = logdir
+ self._closed = False
+ if not gfile.IsDirectory(self._logdir):
+ gfile.MakeDirs(self._logdir)
+
+ with self._session.graph.as_default():
+ with ops.name_scope('filewriter'):
+ file_writer = summary_ops_v2.create_file_writer(
+ logdir=self._logdir,
+ max_queue=max_queue,
+ flush_millis=flush_secs * 1000,
+ filename_suffix=filename_suffix)
+ with summary_ops_v2.always_record_summaries(), file_writer.as_default():
+ self._event_placeholder = array_ops.placeholder_with_default(
+ constant_op.constant('unused', dtypes.string),
+ shape=[])
+ self._add_event_op = summary_ops_v2.import_event(
+ self._event_placeholder)
+ self._init_op = file_writer.init()
+ self._flush_op = file_writer.flush()
+ self._close_op = file_writer.close()
+ self._session.run(self._init_op)
+
+ def get_logdir(self):
+ """Returns the directory where event file will be written."""
+ return self._logdir
+
+ def reopen(self):
+ """Reopens the EventFileWriter.
+
+ Can be called after `close()` to add more events in the same directory.
+ The events will go into a new events file.
+
+ Does nothing if the EventFileWriter was not closed.
+ """
+ if self._closed:
+ self._closed = False
+ self._session.run(self._init_op)
+
+ def add_event(self, event):
+ """Adds an event to the event file.
+
+ Args:
+ event: An `Event` protocol buffer.
+ """
+ if not self._closed:
+ event_pb = event.SerializeToString()
+ self._session.run(
+ self._add_event_op, feed_dict={self._event_placeholder: event_pb})
+
+ def flush(self):
+ """Flushes the event file to disk.
+
+ Call this method to make sure that all pending events have been written to
+ disk.
+ """
+ self._session.run(self._flush_op)
+
+ def close(self):
+ """Flushes the event file to disk and close the file.
+
+ Call this method when you do not need the summary writer anymore.
+ """
+ if not self._closed:
+ self.flush()
+ self._session.run(self._close_op)
+ self._closed = True
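EventFileWriterV2 is an internal class normally reached through
`FileWriter(..., session=...)`, but a direct usage sketch may clarify the
session-backed flow (hypothetical logdir path):

    import tensorflow as tf
    from tensorflow.core.util import event_pb2
    from tensorflow.python.summary.writer.event_file_writer_v2 import (
        EventFileWriterV2)

    with tf.Session() as sess:
      ev_writer = EventFileWriterV2(sess, '/tmp/logdir')  # runs the init op
      event = event_pb2.Event(wall_time=0.0, step=1)
      ev_writer.add_event(event)  # session.run() of the import_event op
      ev_writer.flush()           # session.run() of the flush op
      ev_writer.close()           # flush, then session.run() of the close op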
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index 57f78c156b..aca084fc91 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -32,6 +32,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import plugin_asset
from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
+from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
from tensorflow.python.util.tf_export import tf_export
_PLUGINS_DIR = "plugins"
@@ -286,6 +287,11 @@ class FileWriter(SummaryToEventTransformer):
file contents asynchronously. This allows a training program to call methods
to add data to the file directly from the training loop, without slowing down
training.
+
+ When constructed with a `tf.Session` parameter, a `FileWriter` instead forms
+ a compatibility layer over new graph-based summaries (`tf.contrib.summary`)
+ to facilitate the use of new summary writing with pre-existing code that
+ expects a `FileWriter` instance.
"""
def __init__(self,
@@ -294,10 +300,11 @@ class FileWriter(SummaryToEventTransformer):
max_queue=10,
flush_secs=120,
graph_def=None,
- filename_suffix=None):
- """Creates a `FileWriter` and an event file.
+ filename_suffix=None,
+ session=None):
+ """Creates a `FileWriter`, optionally shared within the given session.
- On construction the summary writer creates a new event file in `logdir`.
+ Typically, constructing a file writer creates a new event file in `logdir`.
This event file will contain `Event` protocol buffers constructed when you
call one of the following functions: `add_summary()`, `add_session_log()`,
`add_event()`, or `add_graph()`.
@@ -317,13 +324,16 @@ class FileWriter(SummaryToEventTransformer):
writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
- The other arguments to the constructor control the asynchronous writes to
- the event file:
-
- * `flush_secs`: How often, in seconds, to flush the added summaries
- and events to disk.
- * `max_queue`: Maximum number of summaries or events pending to be
- written to disk before one of the 'add' calls block.
+ The `session` argument to the constructor makes the returned `FileWriter` a
+ compatibility layer over new graph-based summaries (`tf.contrib.summary`).
+ Crucially, this means the underlying writer resource and events file will
+ be shared with any other `FileWriter` using the same `session` and `logdir`,
+ and with any `tf.contrib.summary.SummaryWriter` in this session using
+ the same shared resource name (which by default is scoped to the logdir). If
+ no such resource exists, one will be created using the remaining arguments
+ to this constructor, but if one already exists those arguments are ignored.
+ In either case, ops will be added to `session.graph` to control the
+ underlying file writer resource. See `tf.contrib.summary` for more details.
Args:
logdir: A string. Directory where event file will be written.
@@ -334,6 +344,7 @@ class FileWriter(SummaryToEventTransformer):
graph_def: DEPRECATED: Use the `graph` argument instead.
filename_suffix: A string. Every event file's name is suffixed with
`suffix`.
+ session: A `tf.Session` object. See details above.
Raises:
RuntimeError: If called with eager execution enabled.
@@ -347,9 +358,12 @@ class FileWriter(SummaryToEventTransformer):
raise RuntimeError(
"tf.summary.FileWriter is not compatible with eager execution. "
"Use tf.contrib.summary instead.")
-
- event_writer = EventFileWriter(logdir, max_queue, flush_secs,
- filename_suffix)
+ if session is not None:
+ event_writer = EventFileWriterV2(
+ session, logdir, max_queue, flush_secs, filename_suffix)
+ else:
+ event_writer = EventFileWriter(logdir, max_queue, flush_secs,
+ filename_suffix)
super(FileWriter, self).__init__(event_writer, graph, graph_def)
def __enter__(self):
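Per the docstring above, two session-backed FileWriters on the same session
and logdir wrap one shared resource; a sketch (hypothetical logdir path):

    import tensorflow as tf

    logdir = '/tmp/logdir'  # hypothetical path
    with tf.Session() as sess:
      writer_a = tf.summary.FileWriter(logdir, session=sess)
      writer_b = tf.summary.FileWriter(logdir, session=sess)
      # Both wrap the resource shared under 'logdir:' + logdir, so their
      # events land in a single events file, in order.
      one = tf.Summary(value=[tf.Summary.Value(tag='one', simple_value=1.0)])
      two = tf.Summary(value=[tf.Summary.Value(tag='two', simple_value=2.0)])
      writer_a.add_summary(one, global_step=1)
      writer_b.add_summary(two, global_step=2)
      writer_a.flush()  # flushes the shared resource for both writers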
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index 88ade0aac3..dc990c2602 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -29,10 +29,12 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.util import event_pb2
from tensorflow.core.util.event_pb2 import SessionLog
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import plugin_asset
@@ -42,7 +44,10 @@ from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.util import compat
-class SummaryWriterTestCase(test.TestCase):
+class FileWriterTestCase(test.TestCase):
+
+ def _FileWriter(self, *args, **kwargs):
+ return writer.FileWriter(*args, **kwargs)
def _TestDir(self, test_name):
test_dir = os.path.join(self.get_temp_dir(), test_name)
@@ -96,7 +101,7 @@ class SummaryWriterTestCase(test.TestCase):
def testAddingSummaryGraphAndRunMetadata(self):
test_dir = self._CleanTestDir("basics")
- sw = writer.FileWriter(test_dir)
+ sw = self._FileWriter(test_dir)
sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
sw.add_summary(
@@ -171,7 +176,7 @@ class SummaryWriterTestCase(test.TestCase):
test_dir = self._CleanTestDir("basics_named_graph")
with ops.Graph().as_default() as g:
constant_op.constant([12], name="douze")
- sw = writer.FileWriter(test_dir, graph=g)
+ sw = self._FileWriter(test_dir, graph=g)
sw.close()
self._assertEventsWithGraph(test_dir, g, True)
@@ -179,7 +184,7 @@ class SummaryWriterTestCase(test.TestCase):
test_dir = self._CleanTestDir("basics_positional_graph")
with ops.Graph().as_default() as g:
constant_op.constant([12], name="douze")
- sw = writer.FileWriter(test_dir, g)
+ sw = self._FileWriter(test_dir, g)
sw.close()
self._assertEventsWithGraph(test_dir, g, True)
@@ -188,7 +193,7 @@ class SummaryWriterTestCase(test.TestCase):
with ops.Graph().as_default() as g:
constant_op.constant([12], name="douze")
gd = g.as_graph_def()
- sw = writer.FileWriter(test_dir, graph_def=gd)
+ sw = self._FileWriter(test_dir, graph_def=gd)
sw.close()
self._assertEventsWithGraph(test_dir, g, False)
@@ -197,7 +202,7 @@ class SummaryWriterTestCase(test.TestCase):
with ops.Graph().as_default() as g:
constant_op.constant([12], name="douze")
gd = g.as_graph_def()
- sw = writer.FileWriter(test_dir, gd)
+ sw = self._FileWriter(test_dir, gd)
sw.close()
self._assertEventsWithGraph(test_dir, g, False)
@@ -207,18 +212,18 @@ class SummaryWriterTestCase(test.TestCase):
with ops.Graph().as_default() as g:
constant_op.constant([12], name="douze")
gd = g.as_graph_def()
- sw = writer.FileWriter(test_dir, graph=g, graph_def=gd)
+ sw = self._FileWriter(test_dir, graph=g, graph_def=gd)
sw.close()
def testNeitherGraphNorGraphDef(self):
with self.assertRaises(TypeError):
test_dir = self._CleanTestDir("basics_string_instead_of_graph")
- sw = writer.FileWriter(test_dir, "string instead of graph object")
+ sw = self._FileWriter(test_dir, "string instead of graph object")
sw.close()
def testCloseAndReopen(self):
test_dir = self._CleanTestDir("close_and_reopen")
- sw = writer.FileWriter(test_dir)
+ sw = self._FileWriter(test_dir)
sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
sw.close()
# Sleep at least one second to make sure we get a new event file name.
@@ -261,7 +266,7 @@ class SummaryWriterTestCase(test.TestCase):
def testNonBlockingClose(self):
test_dir = self._CleanTestDir("non_blocking_close")
- sw = writer.FileWriter(test_dir)
+ sw = self._FileWriter(test_dir)
# Sleep 1.2 seconds to make sure event queue is empty.
time.sleep(1.2)
time_before_close = time.time()
@@ -270,7 +275,7 @@ class SummaryWriterTestCase(test.TestCase):
def testWithStatement(self):
test_dir = self._CleanTestDir("with_statement")
- with writer.FileWriter(test_dir) as sw:
+ with self._FileWriter(test_dir) as sw:
sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
self.assertEquals(1, len(event_paths))
@@ -280,7 +285,7 @@ class SummaryWriterTestCase(test.TestCase):
# protocol buffers correctly.
def testAddingSummariesFromSessionRunCalls(self):
test_dir = self._CleanTestDir("global_step")
- sw = writer.FileWriter(test_dir)
+ sw = self._FileWriter(test_dir)
with self.test_session():
i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
@@ -327,7 +332,7 @@ class SummaryWriterTestCase(test.TestCase):
def testPluginMetadataStrippedFromSubsequentEvents(self):
test_dir = self._CleanTestDir("basics")
- sw = writer.FileWriter(test_dir)
+ sw = self._FileWriter(test_dir)
sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1)
@@ -386,7 +391,7 @@ class SummaryWriterTestCase(test.TestCase):
def testFileWriterWithSuffix(self):
test_dir = self._CleanTestDir("test_suffix")
- sw = writer.FileWriter(test_dir, filename_suffix="_test_suffix")
+ sw = self._FileWriter(test_dir, filename_suffix="_test_suffix")
for _ in range(10):
sw.add_summary(
summary_pb2.Summary(value=[
@@ -400,9 +405,178 @@ class SummaryWriterTestCase(test.TestCase):
for filename in event_filenames:
self.assertTrue(filename.endswith("_test_suffix"))
+ def testPluginAssetSerialized(self):
+ class ExamplePluginAsset(plugin_asset.PluginAsset):
+ plugin_name = "example"
+
+ def assets(self):
+ return {"foo.txt": "foo!", "bar.txt": "bar!"}
+
+ with ops.Graph().as_default() as g:
+ plugin_asset.get_plugin_asset(ExamplePluginAsset)
+
+ logdir = self.get_temp_dir()
+ fw = self._FileWriter(logdir)
+ fw.add_graph(g)
+ plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, "example")
+
+ with gfile.Open(os.path.join(plugin_dir, "foo.txt"), "r") as f:
+ content = f.read()
+ self.assertEqual(content, "foo!")
+
+ with gfile.Open(os.path.join(plugin_dir, "bar.txt"), "r") as f:
+ content = f.read()
+ self.assertEqual(content, "bar!")
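The relocated plugin-asset test above writes assets under the logdir's plugin directory. The resulting layout, assuming writer._PLUGINS_DIR resolves to the conventional "plugins" name (an assumption; the constant's value is not shown in this diff):

    # logdir/
    #   plugins/
    #     example/
    #       foo.txt   -> "foo!"
    #       bar.txt   -> "bar!"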
-class SummaryWriterCacheTest(test.TestCase):
- """SummaryWriterCache tests."""
+
+class SessionBasedFileWriterTestCase(FileWriterTestCase):
+ """Tests for FileWriter behavior when passed a Session argument."""
+
+ def _FileWriter(self, *args, **kwargs):
+ if "session" not in kwargs:
+      # Pass in test_session() as the session. The session is cached for the
+      # duration of this test method, so any later test_session() call with
+      # no graph argument reuses the same underlying Session.
+ with self.test_session() as sess:
+ kwargs["session"] = sess
+ return writer.FileWriter(*args, **kwargs)
+ return writer.FileWriter(*args, **kwargs)
+
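Why the override enters test_session() only to capture the handle: in this TF 1.x test framework, test_session() caches its Session per test method, so the handle remains usable after the with block exits and later graph-less calls return the same object. An illustrative sketch (not part of the diff):

    # Illustrative only: relies on test_session()'s per-method caching.
    with self.test_session() as sess:
      pass  # context exits, but the cached Session object remains usable
    with self.test_session() as sess2:
      assert sess is sess2  # same underlying Session within one test method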
+ def _createTaggedSummary(self, tag):
+ summary = summary_pb2.Summary()
+ summary.value.add(tag=tag)
+ return summary
+
+ def testSharing_withOtherSessionBasedFileWriters(self):
+ logdir = self.get_temp_dir()
+ with session.Session() as sess:
+ # Initial file writer
+ writer1 = writer.FileWriter(session=sess, logdir=logdir)
+ writer1.add_summary(self._createTaggedSummary("one"), 1)
+ writer1.flush()
+
+ # File writer, should share file with writer1
+ writer2 = writer.FileWriter(session=sess, logdir=logdir)
+ writer2.add_summary(self._createTaggedSummary("two"), 2)
+ writer2.flush()
+
+ # File writer with different logdir (shouldn't be in this logdir at all)
+ writer3 = writer.FileWriter(session=sess, logdir=logdir + "-other")
+ writer3.add_summary(self._createTaggedSummary("three"), 3)
+ writer3.flush()
+
+ # File writer in a different session (should be in separate file)
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ with session.Session() as other_sess:
+ writer4 = writer.FileWriter(session=other_sess, logdir=logdir)
+ writer4.add_summary(self._createTaggedSummary("four"), 4)
+ writer4.flush()
+
+ # One more file writer, should share file with writer1
+ writer5 = writer.FileWriter(session=sess, logdir=logdir)
+ writer5.add_summary(self._createTaggedSummary("five"), 5)
+ writer5.flush()
+
+ event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))
+
+ # First file should have tags "one", "two", and "five"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("one", next(events).summary.value[0].tag)
+ self.assertEqual("two", next(events).summary.value[0].tag)
+ self.assertEqual("five", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Second file should have just "four"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("four", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_paths))
+
+ # Just check that the other logdir file exists to be sure we wrote it
+ self.assertTrue(glob.glob(os.path.join(logdir + "-other", "event*")))
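What the assertions above encode: the session-backed writer's shared resource is keyed by session and logdir, so writers 1, 2, and 5 funnel into one events file, while a different session (writer 4) or a different logdir (writer 3) gets its own file. The expected layout, with illustrative timestamps:

    # logdir/events.out.tfevents.<t1>.<host>        <- writers 1, 2, 5
    # logdir/events.out.tfevents.<t2>.<host>        <- writer 4 (new session)
    # logdir-other/events.out.tfevents.<t1>.<host>  <- writer 3 (other logdir)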
+
+ def testSharing_withExplicitSummaryFileWriters(self):
+ logdir = self.get_temp_dir()
+ with session.Session() as sess:
+ # Initial file writer via FileWriter(session=?)
+ writer1 = writer.FileWriter(session=sess, logdir=logdir)
+ writer1.add_summary(self._createTaggedSummary("one"), 1)
+ writer1.flush()
+
+ # Next one via create_file_writer(), should use same file
+ writer2 = summary_ops_v2.create_file_writer(logdir=logdir)
+ with summary_ops_v2.always_record_summaries(), writer2.as_default():
+ summary2 = summary_ops_v2.scalar("two", 2.0, step=2)
+ sess.run(writer2.init())
+ sess.run(summary2)
+ sess.run(writer2.flush())
+
+ # Next has different shared name, should be in separate file
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ writer3 = summary_ops_v2.create_file_writer(logdir=logdir, name="other")
+ with summary_ops_v2.always_record_summaries(), writer3.as_default():
+ summary3 = summary_ops_v2.scalar("three", 3.0, step=3)
+ sess.run(writer3.init())
+ sess.run(summary3)
+ sess.run(writer3.flush())
+
+ # Next uses a second session, should be in separate file
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ with session.Session() as other_sess:
+ writer4 = summary_ops_v2.create_file_writer(logdir=logdir)
+ with summary_ops_v2.always_record_summaries(), writer4.as_default():
+ summary4 = summary_ops_v2.scalar("four", 4.0, step=4)
+ other_sess.run(writer4.init())
+ other_sess.run(summary4)
+ other_sess.run(writer4.flush())
+
+      # Next via FileWriter(session=?) uses the same second session, so it
+      # should land in the same separate file. (This checks sharing in the
+      # other direction.)
+ writer5 = writer.FileWriter(session=other_sess, logdir=logdir)
+ writer5.add_summary(self._createTaggedSummary("five"), 5)
+ writer5.flush()
+
+ # One more via create_file_writer(), should use same file
+ writer6 = summary_ops_v2.create_file_writer(logdir=logdir)
+ with summary_ops_v2.always_record_summaries(), writer6.as_default():
+ summary6 = summary_ops_v2.scalar("six", 6.0, step=6)
+ sess.run(writer6.init())
+ sess.run(summary6)
+ sess.run(writer6.flush())
+
+ event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))
+
+ # First file should have tags "one", "two", and "six"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("one", next(events).summary.value[0].tag)
+ self.assertEqual("two", next(events).summary.value[0].tag)
+ self.assertEqual("six", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Second file should have just "three"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("three", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Third file should have "four" and "five"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("four", next(events).summary.value[0].tag)
+ self.assertEqual("five", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_paths))
+
+
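The test above doubles as a recipe for the new interop. A minimal user-facing sketch modeled on it; the "session" kwarg on tf.summary.FileWriter is the feature added by this change, and the logdir path is hypothetical:

    import tensorflow as tf
    from tensorflow.core.framework import summary_pb2
    from tensorflow.python.ops import summary_ops_v2

    logdir = "/tmp/interop_demo"  # hypothetical path
    with tf.Session() as sess:
      # Session-backed FileWriter: writes through the shared graph resource.
      fw = tf.summary.FileWriter(logdir=logdir, session=sess)
      s = summary_pb2.Summary()
      s.value.add(tag="loss", simple_value=0.5)
      fw.add_summary(s, global_step=1)
      fw.flush()

      # A create_file_writer() for the same logdir resolves to the same
      # shared resource, so both APIs append to one events file.
      w = summary_ops_v2.create_file_writer(logdir=logdir)
      with summary_ops_v2.always_record_summaries(), w.as_default():
        op = summary_ops_v2.scalar("loss", 0.25, step=2)
      sess.run(w.init())
      sess.run(op)
      sess.run(w.flush())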
+class FileWriterCacheTest(test.TestCase):
+ """FileWriterCache tests."""
def _test_dir(self, test_name):
"""Create an empty dir to use for tests.
@@ -448,32 +622,5 @@ class SummaryWriterCacheTest(test.TestCase):
self.assertFalse(sw1 == sw2)
-class ExamplePluginAsset(plugin_asset.PluginAsset):
- plugin_name = "example"
-
- def assets(self):
- return {"foo.txt": "foo!", "bar.txt": "bar!"}
-
-
-class PluginAssetsTest(test.TestCase):
-
- def testPluginAssetSerialized(self):
- with ops.Graph().as_default() as g:
- plugin_asset.get_plugin_asset(ExamplePluginAsset)
-
- logdir = self.get_temp_dir()
- fw = writer.FileWriter(logdir)
- fw.add_graph(g)
- plugin_dir = os.path.join(logdir, writer._PLUGINS_DIR, "example")
-
- with gfile.Open(os.path.join(plugin_dir, "foo.txt"), "r") as f:
- content = f.read()
- self.assertEqual(content, "foo!")
-
- with gfile.Open(os.path.join(plugin_dir, "bar.txt"), "r") as f:
- content = f.read()
- self.assertEqual(content, "bar!")
-
-
if __name__ == "__main__":
test.main()
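The golden API file below records the widened constructor signature. A quick local check (inspect.getargspec matches the Python 2-era argspec format these golden files use, though it is deprecated in Python 3):

    import inspect
    import tensorflow as tf

    # Expect 'session' as the final keyword argument with default None,
    # matching the updated golden argspec below.
    print(inspect.getargspec(tf.summary.FileWriter.__init__))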
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt
index dcf747971b..6b65b0ace3 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'logdir\', \'graph\', \'max_queue\', \'flush_secs\', \'graph_def\', \'filename_suffix\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'120\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'logdir\', \'graph\', \'max_queue\', \'flush_secs\', \'graph_def\', \'filename_suffix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'120\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_event"