diff options
-rw-r--r-- | tensorflow/examples/tutorials/mnist/fully_connected_feed.py | 14 | ||||
-rw-r--r-- | tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 5 |
2 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py index a67055f88f..a8b04d24d0 100644 --- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py +++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py @@ -150,20 +150,24 @@ def run_training(): # Build the summary operation based on the TF collection of Summaries. summary_op = tf.merge_all_summaries() + # Add the variable initializer Op. + init = tf.initialize_all_variables() + # Create a saver for writing training checkpoints. saver = tf.train.Saver() # Create a session for running Ops on the Graph. sess = tf.Session() - # Run the Op to initialize the variables. - init = tf.initialize_all_variables() - sess.run(init) - # Instantiate a SummaryWriter to output summaries and the Graph. summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) - # And then after everything is built, start the training loop. + # And then after everything is built: + + # Run the Op to initialize the variables. + sess.run(init) + + # Start the training loop. for step in xrange(FLAGS.max_steps): start_time = time.time() diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index 83bb5dd165..9f3a03e352 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -171,12 +171,15 @@ with graph.as_default(): similarity = tf.matmul( valid_embeddings, normalized_embeddings, transpose_b=True) + # Add variable initializer. + init = tf.initialize_all_variables() + # Step 5: Begin training. num_steps = 100001 with tf.Session(graph=graph) as session: # We must initialize all variables before we use them. - tf.initialize_all_variables().run() + init.run() print("Initialized") average_loss = 0 |