Diffstat (limited to 'tensorflow/examples/tutorials/word2vec/word2vec_basic.py')
-rw-r--r-- | tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 104
1 file changed, 85 insertions, 19 deletions
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 87cd95165e..7d1650f05e 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -21,6 +21,8 @@ from __future__ import print_function
 import collections
 import math
 import os
+import sys
+import argparse
 import random
 from tempfile import gettempdir
 import zipfile
@@ -30,6 +32,24 @@ from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
+from tensorflow.contrib.tensorboard.plugins import projector
+
+# Give a folder path as an argument with '--log_dir' to save
+# TensorBoard summaries. Default is a log folder in current directory.
+current_path = os.path.dirname(os.path.realpath(sys.argv[0]))
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    '--log_dir',
+    type=str,
+    default=os.path.join(current_path, 'log'),
+    help='The log directory for TensorBoard summaries.')
+FLAGS, unparsed = parser.parse_known_args()
+
+# Create the directory for TensorBoard variables if there is not.
+if not os.path.exists(FLAGS.log_dir):
+  os.makedirs(FLAGS.log_dir)
+
 
 # Step 1: Download the data.
 url = 'http://mattmahoney.net/dc/'
@@ -156,38 +176,47 @@ graph = tf.Graph()
 with graph.as_default():
 
   # Input data.
-  train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
-  train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
-  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
+  with tf.name_scope('inputs'):
+    train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
+    train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
+    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
 
   # Ops and variables pinned to the CPU because of missing GPU implementation
   with tf.device('/cpu:0'):
     # Look up embeddings for inputs.
-    embeddings = tf.Variable(
-        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
-    embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+    with tf.name_scope('embeddings'):
+      embeddings = tf.Variable(
+          tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
+      embed = tf.nn.embedding_lookup(embeddings, train_inputs)
 
     # Construct the variables for the NCE loss
-    nce_weights = tf.Variable(
-        tf.truncated_normal([vocabulary_size, embedding_size],
-                            stddev=1.0 / math.sqrt(embedding_size)))
-    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
+    with tf.name_scope('weights'):
+      nce_weights = tf.Variable(
+          tf.truncated_normal([vocabulary_size, embedding_size],
+                              stddev=1.0 / math.sqrt(embedding_size)))
+    with tf.name_scope('biases'):
+      nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
 
   # Compute the average NCE loss for the batch.
   # tf.nce_loss automatically draws a new sample of the negative labels each
   # time we evaluate the loss.
   # Explanation of the meaning of NCE loss:
   #   http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
-  loss = tf.reduce_mean(
-      tf.nn.nce_loss(weights=nce_weights,
-                     biases=nce_biases,
-                     labels=train_labels,
-                     inputs=embed,
-                     num_sampled=num_sampled,
-                     num_classes=vocabulary_size))
+  with tf.name_scope('loss'):
+    loss = tf.reduce_mean(
+        tf.nn.nce_loss(weights=nce_weights,
+                       biases=nce_biases,
+                       labels=train_labels,
+                       inputs=embed,
+                       num_sampled=num_sampled,
+                       num_classes=vocabulary_size))
+
+  # Add the loss value as a scalar to summary.
+  tf.summary.scalar('loss', loss)
 
   # Construct the SGD optimizer using a learning rate of 1.0.
-  optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
+  with tf.name_scope('optimizer'):
+    optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
 
   # Compute the cosine similarity between minibatch examples and all embeddings.
   norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
@@ -197,13 +226,22 @@ with graph.as_default():
   similarity = tf.matmul(
       valid_embeddings, normalized_embeddings, transpose_b=True)
 
+  # Merge all summaries.
+  merged = tf.summary.merge_all()
+
   # Add variable initializer.
   init = tf.global_variables_initializer()
 
+  # Create a saver.
+  saver = tf.train.Saver()
+
 # Step 5: Begin training.
 num_steps = 100001
 
 with tf.Session(graph=graph) as session:
+  # Open a writer to write summaries.
+  writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph)
+
   # We must initialize all variables before we use them.
   init.run()
   print('Initialized')
@@ -214,10 +252,21 @@ with tf.Session(graph=graph) as session:
         batch_size, num_skips, skip_window)
     feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
 
+    # Define metadata variable.
+    run_metadata = tf.RunMetadata()
+
     # We perform one update step by evaluating the optimizer op (including it
     # in the list of returned values for session.run()
-    _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
+    # Also, evaluate the merged op to get all summaries from the returned "summary" variable.
+    # Feed metadata variable to session for visualizing the graph in TensorBoard.
+    _, summary, loss_val = session.run([optimizer, merged, loss], feed_dict=feed_dict, run_metadata=run_metadata)
     average_loss += loss_val
+
+    # Add returned summaries to writer in each step.
+    writer.add_summary(summary, step)
+    # Add metadata to visualize the graph for the last run.
+    if step == (num_steps - 1):
+      writer.add_run_metadata(run_metadata, 'step%d' % step)
 
     if step % 2000 == 0:
       if step > 0:
@@ -240,6 +289,23 @@ with tf.Session(graph=graph) as session:
         print(log_str)
   final_embeddings = normalized_embeddings.eval()
 
+  # Write corresponding labels for the embeddings.
+  with open(FLAGS.log_dir + '/metadata.tsv', 'w') as f:
+    for i in xrange(vocabulary_size):
+      f.write(reverse_dictionary[i] + '\n')
+
+  # Save the model for checkpoints.
+  saver.save(session, os.path.join(FLAGS.log_dir, "model.ckpt"))
+
+  # Create a configuration for visualizing embeddings with the labels in TensorBoard.
+  config = projector.ProjectorConfig()
+  embedding_conf = config.embeddings.add()
+  embedding_conf.tensor_name = embeddings.name
+  embedding_conf.metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
+  projector.visualize_embeddings(writer, config)
+
+writer.close()
+
 # Step 6: Visualize the embeddings.
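
Note: after this change, running the tutorial script writes its TensorBoard data (the loss scalar summary, the graph, the final-step run metadata, a model checkpoint, and the projector config plus metadata.tsv) to the directory given by --log_dir, which can then be browsed with "tensorboard --logdir <log_dir>". As a minimal sketch of reusing the saved checkpoint outside the script (the tensor name 'embeddings/Variable:0' is an assumption based on the tf.Variable created under name_scope('embeddings') above, not something stated in the diff):

import os
import tensorflow as tf

log_dir = 'log'  # assumed: the default --log_dir used when the script was run
with tf.Session() as session:
  # Rebuild the graph recorded next to the checkpoint and load the trained weights.
  saver = tf.train.import_meta_graph(os.path.join(log_dir, 'model.ckpt.meta'))
  saver.restore(session, os.path.join(log_dir, 'model.ckpt'))
  # The embedding matrix was created as a tf.Variable under name_scope('embeddings').
  embeddings = session.graph.get_tensor_by_name('embeddings/Variable:0')
  print(session.run(embeddings).shape)  # (vocabulary_size, embedding_size)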