diff options
author | Nick Felt <nickfelt@google.com> | 2018-07-24 09:49:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-24 09:53:22 -0700 |
commit | 568727eed199dba04e37f500265b50f96fed455e (patch) | |
tree | 999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/summary | |
parent | f8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff) |
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file.
As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache.
A couple additional smaller changes are:
- the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance.
- the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries().
- EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor.
PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/python/summary')
-rw-r--r-- | tensorflow/python/summary/writer/event_file_writer_v2.py | 71 | ||||
-rw-r--r-- | tensorflow/python/summary/writer/writer.py | 8 | ||||
-rw-r--r-- | tensorflow/python/summary/writer/writer_test.py | 54 |
3 files changed, 103 insertions, 30 deletions
diff --git a/tensorflow/python/summary/writer/event_file_writer_v2.py b/tensorflow/python/summary/writer/event_file_writer_v2.py index 5c66c0f7a8..262182d3b8 100644 --- a/tensorflow/python/summary/writer/event_file_writer_v2.py +++ b/tensorflow/python/summary/writer/event_file_writer_v2.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.client import session as tf_session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -43,11 +44,11 @@ class EventFileWriterV2(object): """Creates an `EventFileWriterV2` and an event file to write to. On construction, this calls `tf.contrib.summary.create_file_writer` within - the graph from `session.graph` to look up a shared summary writer resource - for `logdir` if one exists, and create one if not. Creating the summary + the default graph, which finds and returns a shared summary writer resource + for `logdir` if one exists, and creates one if not. Creating the summary writer resource in turn creates a new event file in `logdir` to be filled with `Event` protocol buffers passed to `add_event`. Graph ops to control - this writer resource are added to `session.graph` during this init call; + this writer resource are added to the default graph during this init call; stateful methods on this class will call `session.run()` on these ops. Note that because the underlying resource is shared, it is possible that @@ -61,38 +62,50 @@ class EventFileWriterV2(object): no effect. See `tf.contrib.summary.create_file_writer` for details. Args: - session: A `tf.Session`. Session that will hold shared writer resource. - The writer ops will be added to session.graph during this init call. + session: A `tf.Session`, or a callable that provides one which will be + called on-demand. The session will hold the shared writer resource. logdir: A string. Directory where event file will be written. max_queue: Integer. Size of the queue for pending events and summaries. flush_secs: Number. How often, in seconds, to flush the pending events and summaries to disk. filename_suffix: A string. Every event file's name is suffixed with `filename_suffix`. + + Raises: + ValueError: if `session` is not a `tf.Session` or a callable """ - self._session = session + if isinstance(session, tf_session.SessionInterface): + self._session = lambda: session + elif callable(session): + self._session = session + else: + raise ValueError('session must be tf.Session or callable') self._logdir = logdir + self._initialized = False self._closed = False if not gfile.IsDirectory(self._logdir): gfile.MakeDirs(self._logdir) - with self._session.graph.as_default(): - with ops.name_scope('filewriter'): - file_writer = summary_ops_v2.create_file_writer( - logdir=self._logdir, - max_queue=max_queue, - flush_millis=flush_secs * 1000, - filename_suffix=filename_suffix) - with summary_ops_v2.always_record_summaries(), file_writer.as_default(): - self._event_placeholder = array_ops.placeholder_with_default( - constant_op.constant('unused', dtypes.string), - shape=[]) - self._add_event_op = summary_ops_v2.import_event( - self._event_placeholder) - self._init_op = file_writer.init() - self._flush_op = file_writer.flush() - self._close_op = file_writer.close() - self._session.run(self._init_op) + with ops.name_scope('filewriter'): + file_writer = summary_ops_v2.create_file_writer( + logdir=self._logdir, + max_queue=max_queue, + flush_millis=flush_secs * 1000, + filename_suffix=filename_suffix) + with summary_ops_v2.always_record_summaries(), file_writer.as_default(): + self._event_placeholder = array_ops.placeholder_with_default( + constant_op.constant('unused', dtypes.string), + shape=[]) + self._add_event_op = summary_ops_v2.import_event( + self._event_placeholder) + self._init_op = file_writer.init() + self._flush_op = file_writer.flush() + self._close_op = file_writer.close() + + def _init_if_needed(self): + if not self._initialized: + self._session().run(self._init_op) + self._initialized = True def get_logdir(self): """Returns the directory where event file will be written.""" @@ -108,7 +121,6 @@ class EventFileWriterV2(object): """ if self._closed: self._closed = False - self._session.run(self._init_op) def add_event(self, event): """Adds an event to the event file. @@ -117,8 +129,9 @@ class EventFileWriterV2(object): event: An `Event` protocol buffer. """ if not self._closed: + self._init_if_needed() event_pb = event.SerializeToString() - self._session.run( + self._session().run( self._add_event_op, feed_dict={self._event_placeholder: event_pb}) def flush(self): @@ -127,7 +140,9 @@ class EventFileWriterV2(object): Call this method to make sure that all pending events have been written to disk. """ - self._session.run(self._flush_op) + if not self._closed: + self._init_if_needed() + self._session().run(self._flush_op) def close(self): """Flushes the event file to disk and close the file. @@ -135,6 +150,8 @@ class EventFileWriterV2(object): Call this method when you do not need the summary writer anymore. """ if not self._closed: + self._init_if_needed() self.flush() - self._session.run(self._close_op) + self._session().run(self._close_op) self._closed = True + self._initialized = False diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index aca084fc91..2a967ae3a5 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -332,8 +332,11 @@ class FileWriter(SummaryToEventTransformer): the same shared resource name (which by default scoped to the logdir). If no such resource exists, one will be created using the remaining arguments to this constructor, but if one already exists those arguments are ignored. - In either case, ops will be added to `session.graph` to control the + In either case, ops will be added to the default graph to control the underlying file writer resource. See `tf.contrib.summary` for more details. + Instead of an actual `tf.Session`, this argument may also be a callable that + provides a `tf.Session` when invoked (e.g. `tf.get_default_session`), which + will be called on-demand when a session is needed. Args: logdir: A string. Directory where event file will be written. @@ -344,7 +347,8 @@ class FileWriter(SummaryToEventTransformer): graph_def: DEPRECATED: Use the `graph` argument instead. filename_suffix: A string. Every event file's name is suffixed with `suffix`. - session: A `tf.Session` object. See details above. + session: A `tf.Session` object or a callable that provides `tf.Session` + objects. See details above. Raises: RuntimeError: If called with eager execution enabled. diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py index dc990c2602..3380dea317 100644 --- a/tensorflow/python/summary/writer/writer_test.py +++ b/tensorflow/python/summary/writer/writer_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for training_coordinator.py.""" +"""Tests for writer.py.""" from __future__ import absolute_import from __future__ import division @@ -574,6 +574,58 @@ class SessionBasedFileWriterTestCase(FileWriterTestCase): # No more files self.assertRaises(StopIteration, lambda: next(event_paths)) + def testSesssionArgument_callableProvider(self): + logdir = self.get_temp_dir() + setup_writer = summary_ops_v2.create_file_writer(logdir=logdir) + with summary_ops_v2.always_record_summaries(), setup_writer.as_default(): + summary1 = summary_ops_v2.scalar("one", 0.0, step=0) + summary2 = summary_ops_v2.scalar("two", 0.0, step=0) + sess1 = session.Session() + sess1.run(setup_writer.init()) + sess1.run(summary1) + sess1.run(setup_writer.flush()) + time.sleep(1.1) # Ensure filename has a different timestamp + sess2 = session.Session() + sess2.run(setup_writer.init()) + sess2.run(summary2) + sess2.run(setup_writer.flush()) + + # Using get_default_session as session provider should make this FileWriter + # send its summaries to the current default session's shared summary writer + # resource (initializing it as needed). + test_writer = writer.FileWriter( + session=ops.get_default_session, logdir=logdir) + with sess1.as_default(): + test_writer.add_summary(self._createTaggedSummary("won"), 1) + test_writer.flush() + with sess2.as_default(): + test_writer.add_summary(self._createTaggedSummary("too"), 1) + test_writer.flush() + + event_paths = iter(sorted(glob.glob(os.path.join(logdir, "event*")))) + + # First file should have tags "one", "won" + events = summary_iterator.summary_iterator(next(event_paths)) + self.assertEqual("brain.Event:2", next(events).file_version) + self.assertEqual("one", next(events).summary.value[0].tag) + self.assertEqual("won", next(events).summary.value[0].tag) + self.assertRaises(StopIteration, lambda: next(events)) + + # Second file should have tags "two", "too" + events = summary_iterator.summary_iterator(next(event_paths)) + self.assertEqual("brain.Event:2", next(events).file_version) + self.assertEqual("two", next(events).summary.value[0].tag) + self.assertEqual("too", next(events).summary.value[0].tag) + self.assertRaises(StopIteration, lambda: next(events)) + + # No more files + self.assertRaises(StopIteration, lambda: next(event_paths)) + + def testSessionArgument_notSessionOrCallable(self): + logdir = self.get_temp_dir() + self.assertRaises( + ValueError, lambda: writer.FileWriter(session=[], logdir=logdir)) + class FileWriterCacheTest(test.TestCase): """FileWriterCache tests.""" |