Diffstat (limited to 'tensorflow/examples/tutorials/word2vec/word2vec_basic.py')
-rw-r--r--    tensorflow/examples/tutorials/word2vec/word2vec_basic.py    104
1 file changed, 85 insertions(+), 19 deletions(-)
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 87cd95165e..7d1650f05e 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import collections
import math
import os
+import sys
+import argparse
import random
from tempfile import gettempdir
import zipfile
@@ -30,6 +32,24 @@ from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.contrib.tensorboard.plugins import projector
+
+# Give a folder path as an argument with '--log_dir' to save
+# TensorBoard summaries. Default is a log folder in the current directory.
+current_path = os.path.dirname(os.path.realpath(sys.argv[0]))
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ '--log_dir',
+ type=str,
+ default=os.path.join(current_path, 'log'),
+ help='The log directory for TensorBoard summaries.')
+FLAGS, unparsed = parser.parse_known_args()
+
+# Create the directory for TensorBoard variables if it does not exist.
+if not os.path.exists(FLAGS.log_dir):
+ os.makedirs(FLAGS.log_dir)
+
# Step 1: Download the data.
url = 'http://mattmahoney.net/dc/'
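As an editorial aside (not part of the commit), the '--log_dir' flag handling added in the hunk above can be tried on its own. The sketch below uses only the standard library and mirrors the tutorial's pattern; the print line and the tensorboard invocation in the trailing comment are illustrative additions.

```python
# Standalone sketch of the '--log_dir' flag handling added above; illustrative only.
import argparse
import os
import sys

# Default to a 'log' folder next to the script, as in the tutorial.
current_path = os.path.dirname(os.path.realpath(sys.argv[0]))

parser = argparse.ArgumentParser()
parser.add_argument(
    '--log_dir',
    type=str,
    default=os.path.join(current_path, 'log'),
    help='The log directory for TensorBoard summaries.')
# parse_known_args() tolerates extra, unrecognized flags instead of failing.
FLAGS, unparsed = parser.parse_known_args()

# Create the log directory if it does not exist yet.
if not os.path.exists(FLAGS.log_dir):
  os.makedirs(FLAGS.log_dir)

print('TensorBoard summaries will be written to', FLAGS.log_dir)
# After training, inspect the summaries with: tensorboard --logdir <log_dir>
```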
@@ -156,38 +176,47 @@ graph = tf.Graph()
with graph.as_default():
# Input data.
- train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
- train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
- valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
+ with tf.name_scope('inputs'):
+ train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
+ train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
+ valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
# Ops and variables pinned to the CPU because of missing GPU implementation
with tf.device('/cpu:0'):
# Look up embeddings for inputs.
- embeddings = tf.Variable(
- tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
- embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+ with tf.name_scope('embeddings'):
+ embeddings = tf.Variable(
+ tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
+ embed = tf.nn.embedding_lookup(embeddings, train_inputs)
# Construct the variables for the NCE loss
- nce_weights = tf.Variable(
- tf.truncated_normal([vocabulary_size, embedding_size],
- stddev=1.0 / math.sqrt(embedding_size)))
- nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
+ with tf.name_scope('weights'):
+ nce_weights = tf.Variable(
+ tf.truncated_normal([vocabulary_size, embedding_size],
+ stddev=1.0 / math.sqrt(embedding_size)))
+ with tf.name_scope('biases'):
+ nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
# Compute the average NCE loss for the batch.
# tf.nce_loss automatically draws a new sample of the negative labels each
# time we evaluate the loss.
# Explanation of the meaning of NCE loss:
# http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
- loss = tf.reduce_mean(
- tf.nn.nce_loss(weights=nce_weights,
- biases=nce_biases,
- labels=train_labels,
- inputs=embed,
- num_sampled=num_sampled,
- num_classes=vocabulary_size))
+ with tf.name_scope('loss'):
+ loss = tf.reduce_mean(
+ tf.nn.nce_loss(weights=nce_weights,
+ biases=nce_biases,
+ labels=train_labels,
+ inputs=embed,
+ num_sampled=num_sampled,
+ num_classes=vocabulary_size))
+
+ # Add the loss value as a scalar to the summary.
+ tf.summary.scalar('loss', loss)
# Construct the SGD optimizer using a learning rate of 1.0.
- optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
+ with tf.name_scope('optimizer'):
+ optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
# Compute the cosine similarity between minibatch examples and all embeddings.
norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
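The similarity computation left unchanged at the end of this hunk relies on L2-normalized embedding rows, so a single matrix product yields cosine similarities. Re-expressed with NumPy as a small, self-contained sketch (illustrative only, not part of the commit; the toy sizes and ids are made up):

```python
# NumPy sketch of the cosine-similarity trick used in the graph above.
import numpy as np

embeddings = np.random.uniform(-1.0, 1.0, size=(50, 8))  # toy vocabulary of 50 words
valid_ids = np.array([3, 17, 42])                         # toy validation word ids

# L2-normalize each row, exactly what norm / normalized_embeddings do in the graph.
norm = np.sqrt(np.sum(np.square(embeddings), axis=1, keepdims=True))
normalized = embeddings / norm
valid_vectors = normalized[valid_ids]

# similarity[i, j] is the cosine similarity between validation word i and vocabulary word j.
similarity = np.dot(valid_vectors, normalized.T)

# Nearest neighbours: highest similarities, skipping column 0 (the word itself).
nearest = (-similarity).argsort(axis=1)[:, 1:9]
print(nearest)
```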
@@ -197,13 +226,22 @@ with graph.as_default():
similarity = tf.matmul(
valid_embeddings, normalized_embeddings, transpose_b=True)
+ # Merge all summaries.
+ merged = tf.summary.merge_all()
+
# Add variable initializer.
init = tf.global_variables_initializer()
+ # Create a saver.
+ saver = tf.train.Saver()
+
# Step 5: Begin training.
num_steps = 100001
with tf.Session(graph=graph) as session:
+ # Open a writer to write summaries.
+ writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph)
+
# We must initialize all variables before we use them.
init.run()
print('Initialized')
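The merged-summary and FileWriter plumbing added in this hunk can be seen end to end in a minimal, self-contained sketch (TF 1.x API, illustrative only; '/tmp/summary_demo' is an arbitrary path, and the toy loss is made up):

```python
# Minimal sketch of the summary plumbing: scalar summary -> merge_all -> FileWriter.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  x = tf.placeholder(tf.float32, shape=[], name='x')
  loss = tf.square(x, name='loss')
  tf.summary.scalar('loss', loss)
  merged = tf.summary.merge_all()  # one op that evaluates every registered summary

with tf.Session(graph=graph) as session:
  # The writer records the graph once and then one summary protobuf per step.
  writer = tf.summary.FileWriter('/tmp/summary_demo', session.graph)
  for step in range(10):
    summary, _ = session.run([merged, loss], feed_dict={x: float(step)})
    writer.add_summary(summary, step)
  writer.close()
```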
@@ -214,10 +252,21 @@ with tf.Session(graph=graph) as session:
batch_size, num_skips, skip_window)
feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
+ # Define metadata variable.
+ run_metadata = tf.RunMetadata()
+
# We perform one update step by evaluating the optimizer op (including it
# in the list of returned values for session.run()
- _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
+ # Also evaluate the merged op so all summaries are collected into the
+ # returned 'summary' value.
+ # Feed the metadata variable to session.run for visualizing the graph in TensorBoard.
+ _, summary, loss_val = session.run([optimizer, merged, loss],
+                                    feed_dict=feed_dict,
+                                    run_metadata=run_metadata)
average_loss += loss_val
+
+ # Add returned summaries to writer in each step.
+ writer.add_summary(summary, step)
+ # Add metadata to visualize the graph for the last run.
+ if step == (num_steps - 1):
+ writer.add_run_metadata(run_metadata, 'step%d' % step)
if step % 2000 == 0:
if step > 0:
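Note that the run metadata collected above is only written for the last step, and session.run is not given any trace options, so the recorded metadata stays lightweight. If detailed compute-time and memory statistics are wanted in TensorBoard's graph view, FULL_TRACE options are typically passed as well. A standalone sketch (illustrative only, not part of the commit; the matmul workload and path are made up):

```python
# Sketch of collecting a full runtime trace into RunMetadata (TF 1.x API).
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  a = tf.random_uniform([1000, 1000])
  b = tf.matmul(a, a)

with tf.Session(graph=graph) as session:
  writer = tf.summary.FileWriter('/tmp/run_metadata_demo', session.graph)
  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()
  session.run(b, options=run_options, run_metadata=run_metadata)
  # The tag must be unique per call; TensorBoard lists it in the graph view.
  writer.add_run_metadata(run_metadata, 'step0')
  writer.close()
```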
@@ -240,6 +289,23 @@ with tf.Session(graph=graph) as session:
print(log_str)
final_embeddings = normalized_embeddings.eval()
+ # Write corresponding labels for the embeddings.
+ with open(os.path.join(FLAGS.log_dir, 'metadata.tsv'), 'w') as f:
+ for i in xrange(vocabulary_size):
+ f.write(reverse_dictionary[i] + '\n')
+
+ # Save the model for checkpoints.
+ saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt'))
+
+ # Create a configuration for visualizing embeddings with the labels in TensorBoard.
+ config = projector.ProjectorConfig()
+ embedding_conf = config.embeddings.add()
+ embedding_conf.tensor_name = embeddings.name
+ embedding_conf.metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
+ projector.visualize_embeddings(writer, config)
+
+writer.close()
+
# Step 6: Visualize the embeddings.
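For reference, the embedding-projector wiring added in the last hunk can also be exercised in isolation. The sketch below is illustrative only: the toy 100x16 embedding, the 'log' directory, and the 'word_<i>' labels are made up, whereas the tutorial uses FLAGS.log_dir, its trained 'embeddings' variable, and reverse_dictionary for the labels.

```python
# Standalone sketch of checkpoint + metadata.tsv + projector config for TensorBoard.
import os
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

log_dir = 'log'
if not os.path.exists(log_dir):
  os.makedirs(log_dir)

graph = tf.Graph()
with graph.as_default():
  embeddings = tf.Variable(tf.random_uniform([100, 16], -1.0, 1.0), name='embeddings')
  init = tf.global_variables_initializer()
  saver = tf.train.Saver()

with tf.Session(graph=graph) as session:
  init.run()

  # One label per row of the embedding matrix, in the same order as the rows.
  with open(os.path.join(log_dir, 'metadata.tsv'), 'w') as f:
    for i in range(100):
      f.write('word_%d\n' % i)

  # Checkpoint the variables so the projector can load the embedding values.
  saver.save(session, os.path.join(log_dir, 'model.ckpt'))

  writer = tf.summary.FileWriter(log_dir, session.graph)
  config = projector.ProjectorConfig()
  embedding_conf = config.embeddings.add()
  embedding_conf.tensor_name = embeddings.name
  embedding_conf.metadata_path = os.path.join(log_dir, 'metadata.tsv')
  projector.visualize_embeddings(writer, config)  # writes projector_config.pbtxt to log_dir
  writer.close()

# Then: tensorboard --logdir log
```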