diff options
Diffstat (limited to 'tensorflow/python/training/summary_io.py')
-rw-r--r-- | tensorflow/python/training/summary_io.py | 78 |
1 files changed, 63 insertions, 15 deletions
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py index 1257230df9..ff92008872 100644 --- a/tensorflow/python/training/summary_io.py +++ b/tensorflow/python/training/summary_io.py @@ -25,11 +25,14 @@ import time import six +from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import summary_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import ops from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile +from tensorflow.python.platform import logging from tensorflow.python.util import compat @@ -53,7 +56,8 @@ class SummaryWriter(object): @@close """ - def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120): + def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120, + graph_def=None): """Creates a `SummaryWriter` and an event file. On construction the summary writer creates a new event file in `logdir`. @@ -61,7 +65,7 @@ class SummaryWriter(object): call one of the following functions: `add_summary()`, `add_session_log()`, `add_event()`, or `add_graph()`. - If you pass a `graph_def` protocol buffer to the constructor it is added to + If you pass a `Graph` to the constructor it is added to the event file. (This is equivalent to calling `add_graph()` later). TensorBoard will pick the graph from the file and display it graphically so @@ -72,8 +76,8 @@ class SummaryWriter(object): ...create a graph... # Launch the graph in a session. sess = tf.Session() - # Create a summary writer, add the 'graph_def' to the event file. - writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def) + # Create a summary writer, add the 'graph' to the event file. + writer = tf.train.SummaryWriter(<some-directory>, sess.graph) ``` The other arguments to the constructor control the asynchronous writes to @@ -86,10 +90,11 @@ class SummaryWriter(object): Args: logdir: A string. Directory where event file will be written. - graph_def: A `GraphDef` protocol buffer. + graph: A `Graph` object, such as `sess.graph`. max_queue: Integer. Size of the queue for pending events and summaries. flush_secs: Number. How often, in seconds, to flush the pending events and summaries to disk. + graph_def: DEPRECATED: Use the `graph` argument instead. """ self._logdir = logdir if not gfile.IsDirectory(self._logdir): @@ -100,8 +105,9 @@ class SummaryWriter(object): self._worker = _EventLoggerThread(self._event_queue, self._ev_writer, flush_secs) self._worker.start() - if graph_def is not None: - self.add_graph(graph_def) + if graph is not None or graph_def is not None: + # Calling it with both graph and graph_def for backward compatibility. + self.add_graph(graph=graph, graph_def=graph_def) def add_summary(self, summary, global_step=None): """Adds a `Summary` protocol buffer to the event file. @@ -154,22 +160,64 @@ class SummaryWriter(object): """ self._event_queue.put(event) - def add_graph(self, graph_def, global_step=None): - """Adds a `GraphDef` protocol buffer to the event file. + def _add_graph_def(self, graph_def, global_step=None): + graph_bytes = graph_def.SerializeToString() + event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes) + if global_step is not None: + event.step = int(global_step) + self._event_queue.put(event) + + def add_graph(self, graph, global_step=None, graph_def=None): + """Adds a `Graph` to the event file. The graph described by the protocol buffer will be displayed by TensorBoard. Most users pass a graph in the constructor instead. Args: - graph_def: A `GraphDef` protocol buffer. + graph: A `Graph` object, such as `sess.graph`. global_step: Number. Optional global step counter to record with the graph. + graph_def: DEPRECATED. Use the `graph` parameter instead. + + Raises: + ValueError: If both graph and graph_def are passed to the method. """ - graph_bytes = graph_def.SerializeToString() - event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes) - if global_step is not None: - event.step = int(global_step) - self._event_queue.put(event) + + if graph is not None and graph_def is not None: + raise ValueError("Please pass only graph, or graph_def (deprecated), " + "but not both.") + + if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph): + # The user passed a `Graph`. + + # Check if the user passed it via the graph or the graph_def argument and + # correct for that. + if not isinstance(graph, ops.Graph): + logging.warning("When passing a `Graph` object, please use the `graph`" + " named argument instead of `graph_def`.") + graph = graph_def + + # Serialize the graph with additional info. + true_graph_def = graph.as_graph_def(add_shapes=True) + elif (isinstance(graph, graph_pb2.GraphDef) + or isinstance(graph_def, graph_pb2.GraphDef)): + # The user passed a `GraphDef`. + logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated." + " Pass a `Graph` object instead, such as `sess.graph`.") + + # Check if the user passed it via the graph or the graph_def argument and + # correct for that. + if isinstance(graph, graph_pb2.GraphDef): + true_graph_def = graph + else: + true_graph_def = graph_def + + else: + # The user passed neither `Graph`, nor `GraphDef`. + raise TypeError("The passed graph must be an instance of `Graph` " + "or the deprecated `GraphDef`") + # Finally, add the graph_def to the summary writer. + self._add_graph_def(true_graph_def, global_step) def flush(self): """Flushes the event file to disk. |