aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/summary_io.py
blob: dd994c53115a2087744c18ed393e3f995e41d215 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""Reads Summaries from and writes Summaries to event files."""

import os.path
import Queue
import threading
import time

from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile


class SummaryWriter(object):
  """Writes `Summary` protocol buffers to event files.

  The `SummaryWriter` class provides a mechanism to create an event file in a
  given directory and add summaries and events to it. The class updates the
  file contents asynchronously. This allows a training program to call methods
  to add data to the file directly from the training loop, without slowing down
  training.

  @@__init__

  @@add_summary
  @@add_event
  @@add_graph

  @@flush
  @@close
  """

  def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120):
    """Creates a `SummaryWriter` and an event file.

    On construction the summary writer creates a new event file in `logdir`.
    This event file will contain `Event` protocol buffers constructed when you
    call one of the following functions: `add_summary()`, `add_event()`, or
    `add_graph()`.

    If you pass a `graph_def` protocol buffer to the constructor it is added to
    the event file. (This is equivalent to calling `add_graph()` later).

    TensorBoard will pick the graph from the file and display it graphically so
    you can interactively explore the graph you built. You will usually pass
    the graph from the session in which you launched it:

    ```python
    ...create a graph...
    # Launch the graph in a session.
    sess = tf.Session()
    # Create a summary writer, add the 'graph_def' to the event file.
    writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def)
    ```

    The other arguments to the constructor control the asynchronous writes to
    the event file:

    *  `flush_secs`: How often, in seconds, to flush the added summaries
       and events to disk.
    *  `max_queue`: Maximum number of summaries or events pending to be
       written to disk before one of the 'add' calls blocks.

    Args:
      logdir: A string. Directory where event file will be written. It is
        created if it does not exist.
      graph_def: A `GraphDef` protocol buffer.
      max_queue: Integer. Size of the queue for pending events and summaries.
      flush_secs: Number. How often, in seconds, to flush the
        pending events and summaries to disk.
    """
    self._logdir = logdir
    if not gfile.IsDirectory(self._logdir):
      gfile.MakeDirs(self._logdir)
    # Must exist before anything calls add_event(), including add_graph()
    # at the end of this constructor.
    self._closed = False
    self._event_queue = Queue.Queue(max_queue)
    self._ev_writer = pywrap_tensorflow.EventsWriter(
        os.path.join(self._logdir, "events"))
    # Background thread that drains the queue into the events writer.
    self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
                                      flush_secs)
    self._worker.start()
    if graph_def is not None:
      self.add_graph(graph_def)

  def add_summary(self, summary, global_step=None):
    """Adds a `Summary` protocol buffer to the event file.

    This method wraps the provided summary in an `Event` protocol buffer
    and adds it to the event file.

    You can pass the output of any summary op, as-is, to this function. You
    can also pass a `Summary` protocol buffer that you manufacture with your
    own data. This is commonly done to report evaluation results in event
    files.

    Args:
      summary: A `Summary` protocol buffer, optionally serialized as a string.
      global_step: Number. Optional global step value to record with the
        summary.
    """
    if isinstance(summary, basestring):
      # Summary ops return serialized protos; parse them back into a message.
      summ = summary_pb2.Summary()
      summ.ParseFromString(summary)
      summary = summ
    event = event_pb2.Event(wall_time=time.time(), summary=summary)
    if global_step is not None:
      event.step = long(global_step)
    self.add_event(event)

  def add_event(self, event):
    """Adds an event to the event file.

    The event is queued and written to disk by a background thread. This call
    blocks while `max_queue` events are already pending. Events added after
    `close()` has been called are dropped, because the underlying writer has
    been closed.

    Args:
      event: An `Event` protocol buffer.
    """
    if not self._closed:
      self._event_queue.put(event)

  def add_graph(self, graph_def, global_step=None):
    """Adds a `GraphDef` protocol buffer to the event file.

    The graph described by the protocol buffer will be displayed by
    TensorBoard. Most users pass a graph in the constructor instead.

    Args:
      graph_def: A `GraphDef` protocol buffer.
      global_step: Number. Optional global step counter to record with the
        graph.
    """
    event = event_pb2.Event(wall_time=time.time(), graph_def=graph_def)
    if global_step is not None:
      event.step = long(global_step)
    # Route through add_event() so graphs obey the same rules as summaries.
    self.add_event(event)

  def flush(self):
    """Flushes the event file to disk.

    Call this method to make sure that all pending events have been written to
    disk. After `close()` this is a no-op: closing already flushed everything.
    """
    if self._closed:
      return
    # Wait until the background thread has written every queued event.
    self._event_queue.join()
    self._ev_writer.Flush()

  def close(self):
    """Flushes the event file to disk and closes the file.

    Call this method when you do not need the summary writer anymore. Calling
    it more than once has no further effect.
    """
    if self._closed:
      return
    self.flush()
    # Mark closed before closing the native writer so no new event is queued
    # for a writer that is going away.
    self._closed = True
    self._ev_writer.Close()


class _EventLoggerThread(threading.Thread):
  """Background thread that writes queued events to an event writer."""

  def __init__(self, queue, ev_writer, flush_secs):
    """Creates an _EventLoggerThread.

    Args:
      queue: A Queue from which to dequeue events.
      ev_writer: An event writer providing `WriteEvent(event)` and `Flush()`.
        Used to log events for the visualizer.
      flush_secs: How often, in seconds, to flush the
        pending file to disk.
    """
    threading.Thread.__init__(self)
    # Daemon thread, so a writer that is never closed does not keep the
    # interpreter alive at exit.
    self.daemon = True
    self._queue = queue
    self._ev_writer = ev_writer
    self._flush_secs = flush_secs
    # The first event will be flushed immediately.
    self._next_event_flush_time = 0

  def run(self):
    """Writes events forever.

    After each write, the writer is flushed if `flush_secs` have elapsed since
    the previous flush.
    """
    # NOTE(review): there is no except clause, so an exception from
    # WriteEvent()/Flush() ends this thread. Later events are then never
    # dequeued, and a Queue.join() on the queue would block.
    while True:
      event = self._queue.get()
      try:
        self._ev_writer.WriteEvent(event)
        # Flush the event writer every so often.
        now = time.time()
        if now > self._next_event_flush_time:
          self._ev_writer.Flush()
          # Schedule the next flush `flush_secs` seconds from now.
          self._next_event_flush_time = now + self._flush_secs
      finally:
        # Always mark the item done, so that a Queue.join() (as used by
        # SummaryWriter.flush()) does not wait on it forever.
        self._queue.task_done()


def summary_iterator(path):
  """Iterates over the `Event` protocol buffers stored in an event file.

  Use this to read back events written to an event file. The function is a
  generator: the file is only opened and read as you iterate, and each record
  is decoded into an `Event` protocol buffer when it is yielded.

  Example: Print the contents of an events file.

  ```python
  for e in tf.summary_iterator(path to events file):
      print e
  ```

  Example: Print selected summary values.

  ```python
  # This example supposes that the events file contains summaries with a
  # summary value tag 'loss'.  These could have been added by calling
  # `add_summary()`, passing the output of a scalar summary op created
  # with: `tf.scalar_summary(['loss'], loss_tensor)`.
  for e in tf.summary_iterator(path to events file):
      for v in e.summary.value:
          if v.tag == 'loss':
              print v.simple_value
  ```

  See the protocol buffer definitions of
  [Event](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/util/event.proto)
  and
  [Summary](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
  for more information about their attributes.

  Args:
    path: The path to an event file created by a `SummaryWriter`.

  Yields:
    `Event` protocol buffers, in file order.
  """
  records = tf_record.tf_record_iterator(path)
  for serialized in records:
    yield event_pb2.Event.FromString(serialized)