diff options
-rw-r--r-- | tensorflow/contrib/summary/summary.py | 36 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 12 |
2 files changed, 41 insertions, 7 deletions
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 7d3b8b7437..2d6d7ea6a3 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -18,6 +18,42 @@ The operations in this package are safe to use with eager execution turned on or off. It has a more flexible API that allows summaries to be written directly from ops to places other than event log files, rather than propagating protos from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. + +To use with eager execution enabled, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate a record + # ... + +To use it with graph execution, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model definition code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate an op, + # note the need to run tf.contrib.summary.all_summary_ops() to make sure these + # ops get executed. + # ... + train_op = .... + +with tf.Session(...) as sess: + tf.global_variables_initializer().run() + tf.contrib.summary.initialize(graph=tf.get_default_graph()) + # ... + while not_done_training: + sess.run([train_op, tf.contrib.summary.all_summary_ops()]) + # ... + """ from __future__ import absolute_import diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index a6968d8b2a..068ae35c71 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -154,10 +154,12 @@ def initialize( to @{tf.get_default_session}. Raises: - RuntimeError: If in eager mode, or if the current thread has no - default @{tf.contrib.summary.SummaryWriter}. + RuntimeError: If the current thread has no default + @{tf.contrib.summary.SummaryWriter}. ValueError: If session wasn't passed and no default session. """ + if context.in_eager_mode(): + return if context.context().summary_writer_resource is None: raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") if session is None: @@ -292,13 +294,9 @@ def all_summary_ops(): Returns: The summary ops. - - Raises: - RuntimeError: If in Eager mode. """ if context.in_eager_mode(): - raise RuntimeError( - "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return None return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access |