diff options
author | 2018-07-24 09:49:47 -0700 | |
---|---|---|
committer | 2018-07-24 09:53:22 -0700 | |
commit | 568727eed199dba04e37f500265b50f96fed455e (patch) | |
tree | 999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/contrib/summary | |
parent | f8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff) |
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file.
As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache.
A couple additional smaller changes are:
- the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance.
- the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries().
- EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor.
PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_graph_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index ae8336daaf..409fdf4583 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -228,6 +228,26 @@ class GraphFileTest(test_util.TensorFlowTestCase): sess.run(writer.flush()) self.assertEqual(2, get_total()) + def testSummaryOpsCollector(self): + summary_ops.scalar('x', 1.0, step=1) + with summary_ops.create_file_writer(self.get_temp_dir()).as_default(): + s2 = summary_ops.scalar('x', 1.0, step=1) + collector1 = summary_ops._SummaryOpsCollector() + collector2 = summary_ops._SummaryOpsCollector() + with collector1.capture(): + s3 = summary_ops.scalar('x', 1.0, step=1) + with collector2.capture(): + s4 = summary_ops.scalar('x', 1.0, step=1) + s5 = summary_ops.scalar('x', 1.0, step=1) + s6 = summary_ops.scalar('x', 1.0, step=1) + summary_ops.scalar('six', 1.0, step=1) + + # Ops defined outside summary writer context are ignored; ops defined inside + # SummaryOpsCollector capture context are stored to innermost such context. + self.assertItemsEqual([s2, s6], summary_ops.all_summary_ops()) + self.assertItemsEqual([s3, s5], collector1.collected_ops) + self.assertItemsEqual([s4], collector2.collected_ops) + class GraphDbTest(summary_test_util.SummaryDbTest): |