aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary
diff options
context:
space:
mode:
authorGravatar Nick Felt <nickfelt@google.com>2018-07-24 09:49:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 09:53:22 -0700
commit568727eed199dba04e37f500265b50f96fed455e (patch)
tree999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/summary
parentf8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff)
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file. As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache. A couple additional smaller changes are: - the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance. - the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries(). - EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor. PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/python/summary')
-rw-r--r--tensorflow/python/summary/writer/event_file_writer_v2.py71
-rw-r--r--tensorflow/python/summary/writer/writer.py8
-rw-r--r--tensorflow/python/summary/writer/writer_test.py54
3 files changed, 103 insertions, 30 deletions
diff --git a/tensorflow/python/summary/writer/event_file_writer_v2.py b/tensorflow/python/summary/writer/event_file_writer_v2.py
index 5c66c0f7a8..262182d3b8 100644
--- a/tensorflow/python/summary/writer/event_file_writer_v2.py
+++ b/tensorflow/python/summary/writer/event_file_writer_v2.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -43,11 +44,11 @@ class EventFileWriterV2(object):
"""Creates an `EventFileWriterV2` and an event file to write to.
On construction, this calls `tf.contrib.summary.create_file_writer` within
- the graph from `session.graph` to look up a shared summary writer resource
- for `logdir` if one exists, and create one if not. Creating the summary
+ the default graph, which finds and returns a shared summary writer resource
+ for `logdir` if one exists, and creates one if not. Creating the summary
writer resource in turn creates a new event file in `logdir` to be filled
with `Event` protocol buffers passed to `add_event`. Graph ops to control
- this writer resource are added to `session.graph` during this init call;
+ this writer resource are added to the default graph during this init call;
stateful methods on this class will call `session.run()` on these ops.
Note that because the underlying resource is shared, it is possible that
@@ -61,38 +62,50 @@ class EventFileWriterV2(object):
no effect. See `tf.contrib.summary.create_file_writer` for details.
Args:
- session: A `tf.Session`. Session that will hold shared writer resource.
- The writer ops will be added to session.graph during this init call.
+ session: A `tf.Session`, or a callable that provides one which will be
+ called on-demand. The session will hold the shared writer resource.
logdir: A string. Directory where event file will be written.
max_queue: Integer. Size of the queue for pending events and summaries.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
filename_suffix: A string. Every event file's name is suffixed with
`filename_suffix`.
+
+ Raises:
+ ValueError: if `session` is not a `tf.Session` or a callable
"""
- self._session = session
+ if isinstance(session, tf_session.SessionInterface):
+ self._session = lambda: session
+ elif callable(session):
+ self._session = session
+ else:
+ raise ValueError('session must be tf.Session or callable')
self._logdir = logdir
+ self._initialized = False
self._closed = False
if not gfile.IsDirectory(self._logdir):
gfile.MakeDirs(self._logdir)
- with self._session.graph.as_default():
- with ops.name_scope('filewriter'):
- file_writer = summary_ops_v2.create_file_writer(
- logdir=self._logdir,
- max_queue=max_queue,
- flush_millis=flush_secs * 1000,
- filename_suffix=filename_suffix)
- with summary_ops_v2.always_record_summaries(), file_writer.as_default():
- self._event_placeholder = array_ops.placeholder_with_default(
- constant_op.constant('unused', dtypes.string),
- shape=[])
- self._add_event_op = summary_ops_v2.import_event(
- self._event_placeholder)
- self._init_op = file_writer.init()
- self._flush_op = file_writer.flush()
- self._close_op = file_writer.close()
- self._session.run(self._init_op)
+ with ops.name_scope('filewriter'):
+ file_writer = summary_ops_v2.create_file_writer(
+ logdir=self._logdir,
+ max_queue=max_queue,
+ flush_millis=flush_secs * 1000,
+ filename_suffix=filename_suffix)
+ with summary_ops_v2.always_record_summaries(), file_writer.as_default():
+ self._event_placeholder = array_ops.placeholder_with_default(
+ constant_op.constant('unused', dtypes.string),
+ shape=[])
+ self._add_event_op = summary_ops_v2.import_event(
+ self._event_placeholder)
+ self._init_op = file_writer.init()
+ self._flush_op = file_writer.flush()
+ self._close_op = file_writer.close()
+
+ def _init_if_needed(self):
+ if not self._initialized:
+ self._session().run(self._init_op)
+ self._initialized = True
def get_logdir(self):
"""Returns the directory where event file will be written."""
@@ -108,7 +121,6 @@ class EventFileWriterV2(object):
"""
if self._closed:
self._closed = False
- self._session.run(self._init_op)
def add_event(self, event):
"""Adds an event to the event file.
@@ -117,8 +129,9 @@ class EventFileWriterV2(object):
event: An `Event` protocol buffer.
"""
if not self._closed:
+ self._init_if_needed()
event_pb = event.SerializeToString()
- self._session.run(
+ self._session().run(
self._add_event_op, feed_dict={self._event_placeholder: event_pb})
def flush(self):
@@ -127,7 +140,9 @@ class EventFileWriterV2(object):
Call this method to make sure that all pending events have been written to
disk.
"""
- self._session.run(self._flush_op)
+ if not self._closed:
+ self._init_if_needed()
+ self._session().run(self._flush_op)
def close(self):
"""Flushes the event file to disk and close the file.
@@ -135,6 +150,8 @@ class EventFileWriterV2(object):
Call this method when you do not need the summary writer anymore.
"""
if not self._closed:
+ self._init_if_needed()
self.flush()
- self._session.run(self._close_op)
+ self._session().run(self._close_op)
self._closed = True
+ self._initialized = False
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index aca084fc91..2a967ae3a5 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -332,8 +332,11 @@ class FileWriter(SummaryToEventTransformer):
the same shared resource name (which by default scoped to the logdir). If
no such resource exists, one will be created using the remaining arguments
to this constructor, but if one already exists those arguments are ignored.
- In either case, ops will be added to `session.graph` to control the
+ In either case, ops will be added to the default graph to control the
underlying file writer resource. See `tf.contrib.summary` for more details.
+ Instead of an actual `tf.Session`, this argument may also be a callable that
+ provides a `tf.Session` when invoked (e.g. `tf.get_default_session`), which
+ will be called on-demand when a session is needed.
Args:
logdir: A string. Directory where event file will be written.
@@ -344,7 +347,8 @@ class FileWriter(SummaryToEventTransformer):
graph_def: DEPRECATED: Use the `graph` argument instead.
filename_suffix: A string. Every event file's name is suffixed with
`suffix`.
- session: A `tf.Session` object. See details above.
+ session: A `tf.Session` object or a callable that provides `tf.Session`
+ objects. See details above.
Raises:
RuntimeError: If called with eager execution enabled.
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2602..3380dea317 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for training_coordinator.py."""
+"""Tests for writer.py."""
from __future__ import absolute_import
from __future__ import division
@@ -574,6 +574,58 @@ class SessionBasedFileWriterTestCase(FileWriterTestCase):
# No more files
self.assertRaises(StopIteration, lambda: next(event_paths))
+ def testSesssionArgument_callableProvider(self):
+ logdir = self.get_temp_dir()
+ setup_writer = summary_ops_v2.create_file_writer(logdir=logdir)
+ with summary_ops_v2.always_record_summaries(), setup_writer.as_default():
+ summary1 = summary_ops_v2.scalar("one", 0.0, step=0)
+ summary2 = summary_ops_v2.scalar("two", 0.0, step=0)
+ sess1 = session.Session()
+ sess1.run(setup_writer.init())
+ sess1.run(summary1)
+ sess1.run(setup_writer.flush())
+ time.sleep(1.1) # Ensure filename has a different timestamp
+ sess2 = session.Session()
+ sess2.run(setup_writer.init())
+ sess2.run(summary2)
+ sess2.run(setup_writer.flush())
+
+ # Using get_default_session as session provider should make this FileWriter
+ # send its summaries to the current default session's shared summary writer
+ # resource (initializing it as needed).
+ test_writer = writer.FileWriter(
+ session=ops.get_default_session, logdir=logdir)
+ with sess1.as_default():
+ test_writer.add_summary(self._createTaggedSummary("won"), 1)
+ test_writer.flush()
+ with sess2.as_default():
+ test_writer.add_summary(self._createTaggedSummary("too"), 1)
+ test_writer.flush()
+
+ event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*"))))
+
+ # First file should have tags "one", "won"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("one", next(events).summary.value[0].tag)
+ self.assertEqual("won", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # Second file should have tags "two", "too"
+ events = summary_iterator.summary_iterator(next(event_paths))
+ self.assertEqual("brain.Event:2", next(events).file_version)
+ self.assertEqual("two", next(events).summary.value[0].tag)
+ self.assertEqual("too", next(events).summary.value[0].tag)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ # No more files
+ self.assertRaises(StopIteration, lambda: next(event_paths))
+
+ def testSessionArgument_notSessionOrCallable(self):
+ logdir = self.get_temp_dir()
+ self.assertRaises(
+ ValueError, lambda: writer.FileWriter(session=[], logdir=logdir))
+
class FileWriterCacheTest(test.TestCase):
"""FileWriterCache tests."""