aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/summary_io.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/summary_io.py')
-rw-r--r--tensorflow/python/training/summary_io.py78
1 files changed, 63 insertions, 15 deletions
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index 1257230df9..ff92008872 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -25,11 +25,14 @@ import time
import six
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import ops
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
from tensorflow.python.util import compat
@@ -53,7 +56,8 @@ class SummaryWriter(object):
@@close
"""
- def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120):
+ def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120,
+ graph_def=None):
"""Creates a `SummaryWriter` and an event file.
On construction the summary writer creates a new event file in `logdir`.
@@ -61,7 +65,7 @@ class SummaryWriter(object):
call one of the following functions: `add_summary()`, `add_session_log()`,
`add_event()`, or `add_graph()`.
- If you pass a `graph_def` protocol buffer to the constructor it is added to
+ If you pass a `Graph` to the constructor it is added to
the event file. (This is equivalent to calling `add_graph()` later).
TensorBoard will pick the graph from the file and display it graphically so
@@ -72,8 +76,8 @@ class SummaryWriter(object):
...create a graph...
# Launch the graph in a session.
sess = tf.Session()
- # Create a summary writer, add the 'graph_def' to the event file.
- writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def)
+ # Create a summary writer, add the 'graph' to the event file.
+ writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
@@ -86,10 +90,11 @@ class SummaryWriter(object):
Args:
logdir: A string. Directory where event file will be written.
- graph_def: A `GraphDef` protocol buffer.
+ graph: A `Graph` object, such as `sess.graph`.
max_queue: Integer. Size of the queue for pending events and summaries.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
+ graph_def: DEPRECATED: Use the `graph` argument instead.
"""
self._logdir = logdir
if not gfile.IsDirectory(self._logdir):
@@ -100,8 +105,9 @@ class SummaryWriter(object):
self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
flush_secs)
self._worker.start()
- if graph_def is not None:
- self.add_graph(graph_def)
+ if graph is not None or graph_def is not None:
+ # Calling it with both graph and graph_def for backward compatibility.
+ self.add_graph(graph=graph, graph_def=graph_def)
def add_summary(self, summary, global_step=None):
"""Adds a `Summary` protocol buffer to the event file.
@@ -154,22 +160,64 @@ class SummaryWriter(object):
"""
self._event_queue.put(event)
- def add_graph(self, graph_def, global_step=None):
- """Adds a `GraphDef` protocol buffer to the event file.
+ def _add_graph_def(self, graph_def, global_step=None):
+ graph_bytes = graph_def.SerializeToString()
+ event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes)
+ if global_step is not None:
+ event.step = int(global_step)
+ self._event_queue.put(event)
+
+ def add_graph(self, graph, global_step=None, graph_def=None):
+ """Adds a `Graph` to the event file.
The graph described by the protocol buffer will be displayed by
TensorBoard. Most users pass a graph in the constructor instead.
Args:
- graph_def: A `GraphDef` protocol buffer.
+ graph: A `Graph` object, such as `sess.graph`.
global_step: Number. Optional global step counter to record with the
graph.
+ graph_def: DEPRECATED. Use the `graph` parameter instead.
+
+ Raises:
+ ValueError: If both graph and graph_def are passed to the method.
"""
- graph_bytes = graph_def.SerializeToString()
- event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes)
- if global_step is not None:
- event.step = int(global_step)
- self._event_queue.put(event)
+
+ if graph is not None and graph_def is not None:
+ raise ValueError("Please pass only graph, or graph_def (deprecated), "
+ "but not both.")
+
+ if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph):
+ # The user passed a `Graph`.
+
+ # Check if the user passed it via the graph or the graph_def argument and
+ # correct for that.
+ if not isinstance(graph, ops.Graph):
+ logging.warning("When passing a `Graph` object, please use the `graph`"
+ " named argument instead of `graph_def`.")
+ graph = graph_def
+
+ # Serialize the graph with additional info.
+ true_graph_def = graph.as_graph_def(add_shapes=True)
+ elif (isinstance(graph, graph_pb2.GraphDef)
+ or isinstance(graph_def, graph_pb2.GraphDef)):
+ # The user passed a `GraphDef`.
+ logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated."
+ " Pass a `Graph` object instead, such as `sess.graph`.")
+
+ # Check if the user passed it via the graph or the graph_def argument and
+ # correct for that.
+ if isinstance(graph, graph_pb2.GraphDef):
+ true_graph_def = graph
+ else:
+ true_graph_def = graph_def
+
+ else:
+ # The user passed neither `Graph`, nor `GraphDef`.
+ raise TypeError("The passed graph must be an instance of `Graph` "
+ "or the deprecated `GraphDef`")
+ # Finally, add the graph_def to the summary writer.
+ self._add_graph_def(true_graph_def, global_step)
def flush(self):
"""Flushes the event file to disk.