# TensorBoard: Visualizing Learning

The computations you'll use TensorFlow for, like training a massive
deep neural network, can be complex and confusing. To make it easier to
understand, debug, and optimize TensorFlow programs, we've included a suite of
visualization tools called TensorBoard. You can use TensorBoard to visualize
your TensorFlow graph, plot quantitative metrics about the execution of your
graph, and show additional data, such as images, that pass through it. When
TensorBoard is fully configured, it looks like this:

![MNIST TensorBoard](../../images/mnist_tensorboard.png "MNIST TensorBoard")


## Serializing the data

TensorBoard operates by reading TensorFlow events files, which contain summary
data that you can generate when running TensorFlow. Here's the general
lifecycle for summary data within TensorBoard.

First, create the TensorFlow graph that you'd like to collect summary
data from, and decide which nodes you would like to annotate with
[summary operations](../../api_docs/python/train.md#summary-operations).

For example, suppose you are training a convolutional neural network for
recognizing MNIST digits. You'd like to record how the learning rate
varies over time, and how the objective function is changing. Collect these by
attaching [`scalar_summary`](../../api_docs/python/train.md#scalar_summary) ops
to the nodes that output the learning rate and loss, respectively. Then, give
each `scalar_summary` a meaningful `tag`, like `'learning rate'` or
`'loss function'`.

Perhaps you'd also like to visualize the distributions of activations coming
off a particular layer, or the distribution of gradients or weights. Collect
this data by attaching
[`histogram_summary`](../../api_docs/python/train.md#histogram_summary) ops to
the gradient outputs and to the variable that holds your weights, respectively.

For details on all of the summary operations available, check out the docs on
[summary operations](../../api_docs/python/train.md#summary-operations).

Operations in TensorFlow don't do anything until you run them, or run an op
that depends on their output. The summary nodes that we've just created are
peripheral to your graph: none of the ops you are currently running depend on
them. So, to generate summaries, we need to run all of these summary nodes.
Managing them by hand would be tedious, so use
[`tf.merge_all_summaries`](../../api_docs/python/train.md#merge_all_summaries)
to combine them into a single op that generates all the summary data.

Then, you can just run the merged summary op, which will generate a serialized
`Summary` protobuf object with all of your summary data at a given step.
Finally, to write this summary data to disk, pass the summary protobuf to a
[`tf.train.SummaryWriter`](../../api_docs/python/train.md#SummaryWriter).

The `SummaryWriter` takes a logdir in its constructor. This logdir is
important: it's the directory where all of the events will be written out.
The `SummaryWriter` can also optionally take a `GraphDef` in its constructor;
if it receives one, TensorBoard will visualize your graph as well.
To include tensor shape information in the `GraphDef`, pass
`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter`. This will
give you a much better sense of what flows through the graph: see
[Tensor shape information](../../how_tos/graph_viz/index.md#tensor-shape-information).

Now that you've modified your graph and have a `SummaryWriter`, you're ready to
start running your network! If you want, you could run the merged summary op
every single step, and record a ton of training data. That's likely to be more
data than you need, though. Instead, consider running the merged summary op
every `n` steps.
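The decision of when to record can be factored into a small helper. This is a minimal sketch; `should_record` and its `every_n` parameter are hypothetical names, not part of TensorFlow. With `every_n=10` it selects the same steps as recording every ten steps in the MNIST example.

```python
def should_record(step: int, every_n: int = 10) -> bool:
    """Return True on the steps where the merged summary op should be run."""
    if every_n <= 0:
        raise ValueError("every_n must be a positive integer")
    return step % every_n == 0


# Which of the first 1000 steps would write summaries with every_n=10?
recorded = [step for step in range(1000) if should_record(step)]
print(len(recorded), recorded[:3])  # 100 [0, 10, 20]
```

In a training loop, the helper would guard the call that runs the merged summary op and hands its result to `SummaryWriter.add_summary`; all other steps just run the training op.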

The code example below is a modification of the [simple MNIST tutorial](http://tensorflow.org/tutorials/mnist/beginners/index.md), in which we have
added some summary ops and run them every ten steps. If you run this and then
launch `tensorboard --logdir=/tmp/mnist_logs`, you'll be able to visualize
statistics, such as how the weights or accuracy varied during training.
The code below is an excerpt; the full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot labels and start an interactive session,
# so that .run() and .eval() below have a default session to use.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
sess = tf.InteractiveSession()

# Create the model
x = tf.placeholder(tf.float32, [None, 784], name="x-input")
W = tf.Variable(tf.zeros([784, 10]), name="weights")
b = tf.Variable(tf.zeros([10]), name="bias")

# Use a name scope to organize nodes in the graph visualizer
with tf.name_scope("Wx_b"):
  y = tf.nn.softmax(tf.matmul(x, W) + b)

# Add summary ops to collect data
tf.histogram_summary("weights", W)
tf.histogram_summary("biases", b)
tf.histogram_summary("y", y)

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10], name="y-input")
# More name scopes will clean up the graph representation
with tf.name_scope("xent"):
  cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
  tf.scalar_summary("cross entropy", cross_entropy)
with tf.name_scope("train"):
  train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

with tf.name_scope("test"):
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.scalar_summary("accuracy", accuracy)

# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs",
                                sess.graph.as_graph_def(add_shapes=True))
tf.initialize_all_variables().run()

# Train the model; every 10 steps, feed in test data and record summaries
for i in range(1000):
  if i % 10 == 0:  # Record summary data and the accuracy
    feed = {x: mnist.test.images, y_: mnist.test.labels}
    summary_str, acc = sess.run([merged, accuracy], feed_dict=feed)
    writer.add_summary(summary_str, i)
    print("Accuracy at step %s: %s" % (i, acc))
  else:  # Run one training step on a batch of 100 examples
    batch_xs, batch_ys = mnist.train.next_batch(100)
    feed = {x: batch_xs, y_: batch_ys}
    sess.run(train_step, feed_dict=feed)

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
```

You're now all set to visualize this data using TensorBoard.


## Launching TensorBoard

To run TensorBoard, use the command

```bash
python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
```

where `logdir` points to the directory where the `SummaryWriter` serialized its
data. If this `logdir` directory contains subdirectories that contain
serialized data from separate runs, TensorBoard will visualize the data
from all of those runs. Once TensorBoard is running, navigate your web browser
to `localhost:6006` to view the TensorBoard.
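One way to organize multiple runs is to give each run its own subdirectory under a shared logdir and pass that subdirectory to that run's `SummaryWriter`. The sketch below only builds the directory layout with the Python standard library; the helper `make_run_dir` and the run names `lr_0.01` and `lr_0.001` are hypothetical, and the temporary directory stands in for a real logdir such as `/tmp/mnist_logs`.

```python
import tempfile
from pathlib import Path


def make_run_dir(logdir: Path, run_name: str) -> Path:
    """Create (if needed) and return a per-run subdirectory inside logdir."""
    run_dir = logdir / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


# Hypothetical runs: one subdirectory per training run under a shared logdir.
logdir = Path(tempfile.mkdtemp(prefix="mnist_logs_"))
for name in ["lr_0.01", "lr_0.001"]:
    make_run_dir(logdir, name)

print(sorted(p.name for p in logdir.iterdir()))  # ['lr_0.001', 'lr_0.01']
```

Once each run's `SummaryWriter` has written its events into its own subdirectory, pointing `--logdir` at the parent directory lets TensorBoard show both runs side by side.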

If you installed TensorFlow with pip, `tensorboard` is installed into
the system path, so you can use the simpler command

```bash
tensorboard --logdir=/path/to/log-directory
```

When looking at TensorBoard, you will see the navigation tabs in the top right
corner. Each tab represents a set of serialized data that can be visualized.
If the logs TensorBoard is reading contain no data relevant to a given tab,
that tab displays a message explaining how to serialize data that is
applicable to it.

For in-depth information on how to use the *graph* tab to visualize your graph,
see [TensorBoard: Graph Visualization](../../how_tos/graph_viz/index.md).