diff options
author | 2016-03-24 08:02:12 -0800 | |
---|---|---|
committer | 2016-03-24 09:03:01 -0700 | |
commit | c404448d3b1e44fddc2d6e1c6da9862443112721 (patch) | |
tree | 6c44d4bb1556875829ec5d9c3a0f3c4dc372daf2 /tensorflow | |
parent | 5e06a231344f6324cc5054b867b0eb8cbeb37771 (diff) |
Replace graph_def with graph when passed to the SummaryWriter and update tutorials to reflect the new API.
Change: 118033430
Diffstat (limited to 'tensorflow')
12 files changed, 24 insertions, 37 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py index ef46329cd5..d08a15433a 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py @@ -103,8 +103,7 @@ def run_training(): sess.run(init_op) # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, - graph_def=sess.graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) # Start input enqueue threads. coord = tf.train.Coordinator() diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py index b8bd143ec8..3b43abc81c 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py @@ -113,8 +113,7 @@ def run_training(): feed_dict={labels_initializer: data_sets.train.labels}) # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, - graph_def=sess.graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) # Start input enqueue threads. coord = tf.train.Coordinator() diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py index 8a9884a508..eda1ac5b59 100644 --- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py +++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py @@ -163,8 +163,7 @@ def run_training(): sess.run(init) # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, - graph_def=sess.graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) # And then after everything is built, start the training loop. for step in xrange(FLAGS.max_steps): diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index 33dc13c813..e637032b2e 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -82,8 +82,7 @@ def main(_): # Merge all the summaries and write them out to /tmp/mnist_logs (by default) merged = tf.merge_all_summaries() - writer = tf.train.SummaryWriter(FLAGS.summaries_dir, - sess.graph.as_graph_def(add_shapes=True)) + writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph) tf.initialize_all_variables().run() # Train the model, and feed in test data and record summaries every 10 steps diff --git a/tensorflow/g3doc/how_tos/graph_viz/index.md b/tensorflow/g3doc/how_tos/graph_viz/index.md index 77c9332790..a4689c492b 100644 --- a/tensorflow/g3doc/how_tos/graph_viz/index.md +++ b/tensorflow/g3doc/how_tos/graph_viz/index.md @@ -233,10 +233,9 @@ The images below give an illustration for a piece of a real-life graph. When the serialized `GraphDef` includes tensor shapes, the graph visualizer labels edges with tensor dimensions, and edge thickness reflects total tensor -size. To include tensor shapes in the `GraphDef` pass -`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter` when -serializing the graph. The images below show the CIFAR-10 model with tensor -shape information: +size. To include tensor shapes in the `GraphDef` pass the actual graph object +(as in `sess.graph`) to the `SummaryWriter` when serializing the graph. +The images below show the CIFAR-10 model with tensor shape information: <table width="100%;"> <tr> <td style="width: 100%;"> diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md index 5059f02a73..39a5661f24 100644 --- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md @@ -55,11 +55,10 @@ Finally, to write this summary data to disk, pass the summary protobuf to a The `SummaryWriter` takes a logdir in its constructor - this logdir is quite important, it's the directory where all of the events will be written out. -Also, the `SummaryWriter` can optionally take a `GraphDef` in its constructor. -If it receives one, then TensorBoard will visualize your graph as well. -To include tensor shape information in the `GraphDef`, pass -`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter`. This will -give you a much better sense of what flows through the graph: see +Also, the `SummaryWriter` can optionally take a `Graph` in its constructor. +If it receives a `Graph` object, then TensorBoard will visualize your graph +along with tensor shape information. This will give you a much better sense of +what flows through the graph: see [Tensor shape information](../../how_tos/graph_viz/index.md#tensor-shape-information). Now that you've modified your graph and have a `SummaryWriter`, you're ready to @@ -106,8 +105,7 @@ with tf.name_scope("test") as scope: # Merge all the summaries and write them out to /tmp/mnist_logs merged = tf.merge_all_summaries() -writer = tf.train.SummaryWriter("/tmp/mnist_logs", - sess.graph.as_graph_def(add_shapes=True)) +writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph) tf.initialize_all_variables().run() # Train the model, and feed in test data and record summaries every 10 steps diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py index 196d15e9e4..cf30548e14 100644 --- a/tensorflow/models/embedding/word2vec.py +++ b/tensorflow/models/embedding/word2vec.py @@ -395,8 +395,7 @@ class Word2Vec(object): initial_epoch, initial_words = self._session.run([self._epoch, self._words]) summary_op = tf.merge_all_summaries() - summary_writer = tf.train.SummaryWriter(opts.save_path, - graph_def=self._session.graph_def) + summary_writer = tf.train.SummaryWriter(opts.save_path, self._session.graph) workers = [] for _ in xrange(opts.concurrent_steps): t = threading.Thread(target=self._train_thread_body) diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py index 57eb94c3c9..e94cc3f2f5 100644 --- a/tensorflow/models/image/cifar10/cifar10_eval.py +++ b/tensorflow/models/image/cifar10/cifar10_eval.py @@ -115,7 +115,7 @@ def eval_once(saver, summary_writer, top_k_op, summary_op): def evaluate(): """Eval CIFAR-10 for a number of steps.""" - with tf.Graph().as_default(): + with tf.Graph().as_default() as g: # Get images and labels for CIFAR-10. eval_data = FLAGS.eval_data == 'test' images, labels = cifar10.inputs(eval_data=eval_data) @@ -136,9 +136,7 @@ def evaluate(): # Build the summary operation based on the TF collection of Summaries. summary_op = tf.merge_all_summaries() - graph_def = tf.get_default_graph().as_graph_def(add_shapes=True) - summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, - graph_def=graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g) while True: eval_once(saver, summary_writer, top_k_op, summary_op) diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py index 3e3d21cb53..da4565ff6c 100644 --- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py +++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py @@ -239,9 +239,7 @@ def train(): # Start the queue runners. tf.train.start_queue_runners(sess=sess) - graph_def = sess.graph.as_graph_def(add_shapes=True) - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, - graph_def=graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) for step in xrange(FLAGS.max_steps): start_time = time.time() diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py index cf59894a81..a224ecbf2e 100644 --- a/tensorflow/models/image/cifar10/cifar10_train.py +++ b/tensorflow/models/image/cifar10/cifar10_train.py @@ -93,9 +93,7 @@ def train(): # Start the queue runners. tf.train.start_queue_runners(sess=sess) - graph_def = sess.graph.as_graph_def(add_shapes=True) - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, - graph_def=graph_def) + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) for step in xrange(FLAGS.max_steps): start_time = time.time() diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py index 2050c346cd..8dafaad0db 100644 --- a/tensorflow/python/summary/event_accumulator_test.py +++ b/tensorflow/python/summary/event_accumulator_test.py @@ -596,9 +596,11 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): gfile.MkDir(directory) writer = tf.train.SummaryWriter(directory, max_queue=100) - graph_def = tf.GraphDef(node=[tf.NodeDef(name='A', op='Mul')]) + + with tf.Graph().as_default() as graph: + _ = tf.constant([2.0, 1.0]) # Add a graph to the summary writer. - writer.add_graph(graph_def) + writer.add_graph(graph) run_metadata = tf.RunMetadata() device_stats = run_metadata.step_stats.dev_stats.add() @@ -651,7 +653,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): self.assertEqual(i * 5, sq_events[i].step) self.assertEqual(i, id_events[i].value) self.assertEqual(i * i, sq_events[i].value) - self.assertProtoEquals(graph_def, acc.Graph()) + self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) if __name__ == '__main__': diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py index fcc095020d..79a7ecef24 100644 --- a/tensorflow/python/training/summary_writer_test.py +++ b/tensorflow/python/training/summary_writer_test.py @@ -81,8 +81,7 @@ class SummaryWriterTestCase(tf.test.TestCase): 20) with tf.Graph().as_default() as g: tf.constant([0], name="zero") - gd = g.as_graph_def() - sw.add_graph(gd, global_step=30) + sw.add_graph(g, global_step=30) run_metadata = tf.RunMetadata() device_stats = run_metadata.step_stats.dev_stats.add() @@ -124,7 +123,7 @@ class SummaryWriterTestCase(tf.test.TestCase): self.assertEquals(30, ev.step) ev_graph = tf.GraphDef() ev_graph.ParseFromString(ev.graph_def) - self.assertProtoEquals(gd, ev_graph) + self.assertProtoEquals(g.as_graph_def(add_shapes=True), ev_graph) # The next event should have metadata for the run. ev = next(rr) |