Diffstat (limited to 'tensorflow/models/embedding')
-rw-r--r--  tensorflow/models/embedding/BUILD                 |  74
-rwxr-xr-x  tensorflow/models/embedding/__init__.py           |   0
-rw-r--r--  tensorflow/models/embedding/word2vec.py           | 503
-rw-r--r--  tensorflow/models/embedding/word2vec_kernels.cc   | 287
-rw-r--r--  tensorflow/models/embedding/word2vec_ops.cc       |  56
-rw-r--r--  tensorflow/models/embedding/word2vec_optimized.py | 405
6 files changed, 1325 insertions, 0 deletions
diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD
new file mode 100644
index 0000000000..0fb164b05e
--- /dev/null
+++ b/tensorflow/models/embedding/BUILD
@@ -0,0 +1,74 @@
+# Description:
+# TensorFlow model for word2vec
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "tf_gen_op_wrapper_py")
+
+py_binary(
+ name = "word2vec",
+ srcs = [
+ "word2vec.py",
+ ],
+ deps = [
+ ":gen_word2vec",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_binary(
+ name = "word2vec_optimized",
+ srcs = [
+ "word2vec_optimized.py",
+ ],
+ deps = [
+ ":gen_word2vec",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+cc_library(
+ name = "word2vec_ops",
+ srcs = [
+ "word2vec_ops.cc",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "word2vec_kernels",
+ srcs = [
+ "word2vec_kernels.cc",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core",
+ ],
+ alwayslink = 1,
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_word2vec",
+ out = "gen_word2vec.py",
+ deps = [":word2vec_ops"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
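Note: the tf_gen_op_wrapper_py rule above generates gen_word2vec.py, a Python module wrapping the C++ ops registered by the :word2vec_ops target. A minimal usage sketch (not part of this commit; it simply mirrors the imports used by word2vec.py and word2vec_optimized.py below):

    # Assumes the :gen_word2vec target has been built alongside the kernels.
    from tensorflow.models.embedding import gen_word2vec as word2vec

    # The generated module exposes one Python function per registered op,
    # e.g. word2vec.skipgram(...) and word2vec.neg_train(...), which the
    # training scripts call to build their graphs.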
diff --git a/tensorflow/models/embedding/__init__.py b/tensorflow/models/embedding/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/models/embedding/__init__.py
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
new file mode 100644
index 0000000000..4ebf3d6f27
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec.py
@@ -0,0 +1,503 @@
+"""Multi-threaded word2vec mini-batched skip-gram model.
+
+Trains the model described in:
+(Mikolov et al.) Efficient Estimation of Word Representations in Vector Space
+ICLR 2013.
+http://arxiv.org/abs/1301.3781
+This model does traditional minibatching.
+
+The key ops used are:
+* placeholder for feeding in tensors for each example.
+* embedding_lookup for fetching rows from the embedding matrix.
+* sigmoid_cross_entropy_with_logits to calculate the loss.
+* GradientDescentOptimizer for optimizing the loss.
+* skipgram custom op that does input processing.
+"""
+
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.embedding import gen_word2vec as word2vec
+
+flags = tf.app.flags
+
+flags.DEFINE_string("save_path", None, "Directory to write the model and "
+ "training summaries.")
+flags.DEFINE_string("train_data", None, "Training text file. "
+ "E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
+flags.DEFINE_string(
+    "eval_data", None, "File consisting of analogies of four tokens. "
+    "Embedding 2 - embedding 1 + embedding 3 should be close "
+    "to embedding 4. "
+ "E.g. https://word2vec.googlecode.com/svn/trunk/questions-words.txt.")
+flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
+flags.DEFINE_integer(
+ "epochs_to_train", 15,
+ "Number of epochs to train. Each epoch processes the training data once "
+ "completely.")
+flags.DEFINE_float("learning_rate", 0.2, "Initial learning rate.")
+flags.DEFINE_integer("num_neg_samples", 100,
+ "Negative samples per training example.")
+flags.DEFINE_integer("batch_size", 16,
+ "Number of training examples processed per step "
+ "(size of a minibatch).")
+flags.DEFINE_integer("concurrent_steps", 12,
+ "The number of concurrent training steps.")
+flags.DEFINE_integer("window_size", 5,
+ "The number of words to predict to the left and right "
+ "of the target word.")
+flags.DEFINE_integer("min_count", 5,
+ "The minimum number of word occurrences for it to be "
+ "included in the vocabulary.")
+flags.DEFINE_float("subsample", 1e-3,
+ "Subsample threshold for word occurrence. Words that appear "
+ "with higher frequency will be randomly down-sampled. Set "
+ "to 0 to disable.")
+flags.DEFINE_boolean(
+    "interactive", False,
+    "If true, enters an IPython interactive session to play with the trained "
+    "model. E.g., try model.analogy('france', 'paris', 'russia') and "
+    "model.nearby(['proton', 'elephant', 'maxwell']).")
+flags.DEFINE_integer("statistics_interval", 5,
+ "Print statistics every n seconds.")
+flags.DEFINE_integer("summary_interval", 5,
+                     "Save training summary to file every n seconds (rounded "
+                     "up to statistics interval).")
+flags.DEFINE_integer("checkpoint_interval", 600,
+                     "Checkpoint the model (i.e. save the parameters) every n "
+                     "seconds (rounded up to statistics interval).")
+
+FLAGS = flags.FLAGS
+
+
+class Options(object):
+ """Options used by our word2vec model."""
+
+ def __init__(self):
+ # Model options.
+
+ # Embedding dimension.
+ self.emb_dim = FLAGS.embedding_size
+
+ # Training options.
+ # The training text file.
+ self.train_data = FLAGS.train_data
+
+ # Number of negative samples per example.
+ self.num_samples = FLAGS.num_neg_samples
+
+ # The initial learning rate.
+ self.learning_rate = FLAGS.learning_rate
+
+ # Number of epochs to train. After these many epochs, the learning
+ # rate decays linearly to zero and the training stops.
+ self.epochs_to_train = FLAGS.epochs_to_train
+
+ # Concurrent training steps.
+ self.concurrent_steps = FLAGS.concurrent_steps
+
+ # Number of examples for one training step.
+ self.batch_size = FLAGS.batch_size
+
+ # The number of words to predict to the left and right of the target word.
+ self.window_size = FLAGS.window_size
+
+ # The minimum number of word occurrences for it to be included in the
+ # vocabulary.
+ self.min_count = FLAGS.min_count
+
+ # Subsampling threshold for word occurrence.
+ self.subsample = FLAGS.subsample
+
+ # How often to print statistics.
+ self.statistics_interval = FLAGS.statistics_interval
+
+ # How often to write to the summary file (rounds up to the nearest
+ # statistics_interval).
+ self.summary_interval = FLAGS.summary_interval
+
+ # How often to write checkpoints (rounds up to the nearest statistics
+ # interval).
+ self.checkpoint_interval = FLAGS.checkpoint_interval
+
+ # Where to write out summaries.
+ self.save_path = FLAGS.save_path
+
+ # Eval options.
+ # The text file for eval.
+ self.eval_data = FLAGS.eval_data
+
+
+class Word2Vec(object):
+ """Word2Vec model (Skipgram)."""
+
+ def __init__(self, options, session):
+ self._options = options
+ self._session = session
+ self._word2id = {}
+ self._id2word = []
+ self.build_graph()
+ self.build_eval_graph()
+ self.save_vocab()
+ self._read_analogies()
+
+ def _read_analogies(self):
+ """Reads through the analogy question file.
+
+ Returns:
+ questions: a [n, 4] numpy array containing the analogy question's
+ word ids.
+ questions_skipped: questions skipped due to unknown words.
+ """
+ questions = []
+ questions_skipped = 0
+ with open(self._options.eval_data) as analogy_f:
+ for line in analogy_f:
+ if line.startswith(":"): # Skip comments.
+ continue
+ words = line.strip().lower().split(" ")
+ ids = [self._word2id.get(w.strip()) for w in words]
+ if None in ids or len(ids) != 4:
+ questions_skipped += 1
+ else:
+ questions.append(np.array(ids))
+ print "Eval analogy file: ", self._options.eval_data
+ print "Questions: ", len(questions)
+ print "Skipped: ", questions_skipped
+ self._analogy_questions = np.array(questions, dtype=np.int32)
+
+ def forward(self, examples, labels):
+ """Build the graph for the forward pass."""
+ opts = self._options
+
+ # Declare all variables we need.
+ # Embedding: [vocab_size, emb_dim]
+ init_width = 0.5 / opts.emb_dim
+ emb = tf.Variable(
+ tf.random_uniform(
+ [opts.vocab_size, opts.emb_dim], -init_width, init_width),
+ name="emb")
+ self._emb = emb
+
+ # Softmax weight: [vocab_size, emb_dim]. Transposed.
+ sm_w_t = tf.Variable(
+ tf.zeros([opts.vocab_size, opts.emb_dim]),
+ name="sm_w_t")
+
+    # Softmax bias: [vocab_size].
+ sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b")
+
+ # Global step: scalar, i.e., shape [].
+ self.global_step = tf.Variable(0, name="global_step")
+
+ # Nodes to compute the nce loss w/ candidate sampling.
+ labels_matrix = tf.reshape(
+ tf.cast(labels,
+ dtype=tf.int64),
+ [opts.batch_size, 1])
+
+ # Negative sampling.
+ sampled_ids, _, _ = (tf.nn.fixed_unigram_candidate_sampler(
+ true_classes=labels_matrix,
+ num_true=1,
+ num_sampled=opts.num_samples,
+ unique=True,
+ range_max=opts.vocab_size,
+ distortion=0.75,
+ unigrams=opts.vocab_counts.tolist()))
+
+ # Embeddings for examples: [batch_size, emb_dim]
+ example_emb = tf.nn.embedding_lookup(emb, examples)
+
+ # Weights for labels: [batch_size, emb_dim]
+ true_w = tf.nn.embedding_lookup(sm_w_t, labels)
+ # Biases for labels: [batch_size, 1]
+ true_b = tf.nn.embedding_lookup(sm_b, labels)
+
+ # Weights for sampled ids: [num_sampled, emb_dim]
+ sampled_w = tf.nn.embedding_lookup(sm_w_t, sampled_ids)
+ # Biases for sampled ids: [num_sampled, 1]
+ sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids)
+
+ # True logits: [batch_size, 1]
+ true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
+
+ # Sampled logits: [batch_size, num_sampled]
+    # We replicate sampled noise labels for all examples in the batch
+ # using the matmul.
+ sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples])
+ sampled_logits = tf.matmul(example_emb,
+ sampled_w,
+ transpose_b=True) + sampled_b_vec
+ return true_logits, sampled_logits
+
+ def nce_loss(self, true_logits, sampled_logits):
+ """Build the graph for the NCE loss."""
+
+ # cross-entropy(logits, labels)
+ opts = self._options
+ true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+ true_logits, tf.ones_like(true_logits))
+ sampled_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+ sampled_logits, tf.zeros_like(sampled_logits))
+
+ # NCE-loss is the sum of the true and noise (sampled words)
+ # contributions, averaged over the batch.
+ nce_loss_tensor = (tf.reduce_sum(true_xent) +
+ tf.reduce_sum(sampled_xent)) / opts.batch_size
+ return nce_loss_tensor
+
+ def optimize(self, loss):
+ """Build the graph to optimize the loss function."""
+
+ # Optimizer nodes.
+ # Linear learning rate decay.
+ opts = self._options
+ words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
+ lr = opts.learning_rate * tf.maximum(
+ 0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
+ self._lr = lr
+ optimizer = tf.train.GradientDescentOptimizer(lr)
+ train = optimizer.minimize(loss,
+ global_step=self.global_step,
+ gate_gradients=optimizer.GATE_NONE)
+ self._train = train
+
+ def build_eval_graph(self):
+ """Build the eval graph."""
+ # Eval graph
+
+ # Each analogy task is to predict the 4th word (d) given three
+ # words: a, b, c. E.g., a=italy, b=rome, c=france, we should
+ # predict d=paris.
+
+ # The eval feeds three vectors of word ids for a, b, c, each of
+ # which is of size N, where N is the number of analogies we want to
+ # evaluate in one batch.
+ analogy_a = tf.placeholder(dtype=tf.int32) # [N]
+ analogy_b = tf.placeholder(dtype=tf.int32) # [N]
+ analogy_c = tf.placeholder(dtype=tf.int32) # [N]
+
+ # Normalized word embeddings of shape [vocab_size, emb_dim].
+ nemb = tf.nn.l2_normalize(self._emb, 1)
+
+ # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
+ # They all have the shape [N, emb_dim]
+ a_emb = tf.gather(nemb, analogy_a) # a's embs
+ b_emb = tf.gather(nemb, analogy_b) # b's embs
+ c_emb = tf.gather(nemb, analogy_c) # c's embs
+
+    # We expect that d's embedding vectors on the unit hyper-sphere are
+ # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
+ target = c_emb + (b_emb - a_emb)
+
+ # Compute cosine distance between each pair of target and vocab.
+ # dist has shape [N, vocab_size].
+ dist = tf.matmul(target, nemb, transpose_b=True)
+
+ # For each question (row in dist), find the top 4 words.
+ _, pred_idx = tf.nn.top_k(dist, 4)
+
+ # Nodes for computing neighbors for a given word according to
+ # their cosine distance.
+ nearby_word = tf.placeholder(dtype=tf.int32) # word id
+ nearby_emb = tf.gather(nemb, nearby_word)
+ nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
+ nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
+ min(1000, self._options.vocab_size))
+
+    # Nodes in the constructed graph which are used by training and
+ # evaluation to run/feed/fetch.
+ self._analogy_a = analogy_a
+ self._analogy_b = analogy_b
+ self._analogy_c = analogy_c
+ self._analogy_pred_idx = pred_idx
+ self._nearby_word = nearby_word
+ self._nearby_val = nearby_val
+ self._nearby_idx = nearby_idx
+
+ def build_graph(self):
+ """Build the graph for the full model."""
+ opts = self._options
+ # The training data. A text file.
+ (words, counts, words_per_epoch, self._epoch, self._words, examples,
+ labels) = word2vec.skipgram(filename=opts.train_data,
+ batch_size=opts.batch_size,
+ window_size=opts.window_size,
+ min_count=opts.min_count,
+ subsample=opts.subsample)
+ (opts.vocab_words, opts.vocab_counts,
+ opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
+ opts.vocab_size = len(opts.vocab_words)
+ print "Data file: ", opts.train_data
+ print "Vocab size: ", opts.vocab_size - 1, " + UNK"
+ print "Words per epoch: ", opts.words_per_epoch
+ self._examples = examples
+ self._labels = labels
+ self._id2word = opts.vocab_words
+ for i, w in enumerate(self._id2word):
+ self._word2id[w] = i
+ true_logits, sampled_logits = self.forward(examples, labels)
+ loss = self.nce_loss(true_logits, sampled_logits)
+ tf.scalar_summary("NCE loss", loss)
+ self._loss = loss
+ self.optimize(loss)
+
+ # Properly initialize all variables.
+ tf.initialize_all_variables().run()
+
+ self.saver = tf.train.Saver()
+
+ def save_vocab(self):
+ """Save the vocabulary to a file so the model can be reloaded."""
+ opts = self._options
+ with open(opts.save_path + "/vocab.txt", "w") as f:
+ for i in xrange(opts.vocab_size):
+ f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n")
+
+ def _train_thread_body(self):
+ initial_epoch, = self._session.run([self._epoch])
+ while True:
+ _, epoch = self._session.run([self._train, self._epoch])
+ if epoch != initial_epoch:
+ break
+
+ def train(self):
+ """Train the model."""
+ opts = self._options
+
+ initial_epoch, initial_words = self._session.run([self._epoch, self._words])
+
+ summary_op = tf.merge_all_summaries()
+ summary_writer = tf.train.SummaryWriter(opts.save_path,
+ graph_def=self._session.graph_def)
+ workers = []
+ for _ in xrange(opts.concurrent_steps):
+ t = threading.Thread(target=self._train_thread_body)
+ t.start()
+ workers.append(t)
+
+ last_words, last_time, last_summary_time = initial_words, time.time(), 0
+ last_checkpoint_time = 0
+ while True:
+      time.sleep(opts.statistics_interval)  # Report progress once in a while.
+ (epoch, step, loss, words, lr) = self._session.run(
+ [self._epoch, self.global_step, self._loss, self._words, self._lr])
+ now = time.time()
+ last_words, last_time, rate = words, now, (words - last_words) / (
+ now - last_time)
+ print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
+ (epoch, step, lr, loss, rate)),
+ sys.stdout.flush()
+ if now - last_summary_time > opts.summary_interval:
+ summary_str = self._session.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ last_summary_time = now
+ if now - last_checkpoint_time > opts.checkpoint_interval:
+ self.saver.save(self._session,
+ opts.save_path + "model",
+ global_step=step)
+ last_checkpoint_time = now
+ if epoch != initial_epoch:
+ break
+
+ for t in workers:
+ t.join()
+
+ return epoch
+
+ def _predict(self, analogy):
+ """Predict the top 4 answers for analogy questions."""
+ idx, = self._session.run([self._analogy_pred_idx], {
+ self._analogy_a: analogy[:, 0],
+ self._analogy_b: analogy[:, 1],
+ self._analogy_c: analogy[:, 2]
+ })
+ return idx
+
+  def eval(self):
+    """Evaluate analogy questions and report accuracy."""
+
+ # How many questions we get right at precision@1.
+ correct = 0
+
+ total = self._analogy_questions.shape[0]
+ start = 0
+ while start < total:
+ limit = start + 2500
+ sub = self._analogy_questions[start:limit, :]
+ idx = self._predict(sub)
+ start = limit
+ for question in xrange(sub.shape[0]):
+ for j in xrange(4):
+ if idx[question, j] == sub[question, 3]:
+ # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
+ correct += 1
+ break
+ elif idx[question, j] in sub[question, :3]:
+ # We need to skip words already in the question.
+ continue
+ else:
+ # The correct label is not the precision@1
+ break
+ print
+ print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+ correct * 100.0 / total)
+
+ def analogy(self, w0, w1, w2):
+ """Predict word w3 as in w0:w1 vs w2:w3."""
+ wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
+ idx = self._predict(wid)
+ for c in [self._id2word[i] for i in idx[0, :]]:
+ if c not in [w0, w1, w2]:
+ return c
+ return "unknown"
+
+ def nearby(self, words, num=20):
+ """Prints out nearby words given a list of words."""
+ ids = np.array([self._word2id.get(x, 0) for x in words])
+ vals, idx = self._session.run(
+ [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
+ for i in xrange(len(words)):
+ print "\n%s\n=====================================" % (words[i])
+ for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
+ print "%-20s %6.4f" % (self._id2word[neighbor], distance)
+
+
+def _start_shell(local_ns=None):
+ # An interactive shell is useful for debugging/development.
+ import IPython
+ user_ns = {}
+ if local_ns:
+ user_ns.update(local_ns)
+ user_ns.update(globals())
+ IPython.start_ipython(argv=[], user_ns=user_ns)
+
+
+def main(_):
+ """Train a word2vec model."""
+ opts = Options()
+ with tf.Graph().as_default(), tf.Session() as session:
+ model = Word2Vec(opts, session)
+ for _ in xrange(opts.epochs_to_train):
+ model.train() # Process one epoch
+ model.eval() # Eval analogies.
+ # Perform a final save.
+ model.saver.save(session,
+ opts.save_path + "model",
+ global_step=model.global_step)
+ if FLAGS.interactive:
+ # E.g.,
+ # [0]: model.analogy('france', 'paris', 'russia')
+ # [1]: model.nearby(['proton', 'elephant', 'maxwell'])
+ _start_shell(locals())
+
+
+if __name__ == "__main__":
+ tf.app.run()
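The forward() and nce_loss() methods above assemble the NCE objective in the TensorFlow graph. A minimal numpy sketch of the same computation (not part of this commit; names and shapes follow the comments in forward() and are illustrative only):

    import numpy as np

    def nce_loss_reference(example_emb, true_w, true_b, sampled_w, sampled_b):
      """example_emb, true_w: [batch, dim]; true_b: [batch];
      sampled_w: [num_sampled, dim]; sampled_b: [num_sampled]."""
      # True logits: one score per (example, label) pair.
      true_logits = np.sum(example_emb * true_w, axis=1) + true_b
      # Sampled logits: every example scored against every sampled (noise) id.
      sampled_logits = example_emb.dot(sampled_w.T) + sampled_b
      # sigmoid_cross_entropy_with_logits with labels 1 (true) and 0 (sampled)
      # reduces to softplus(-x) and softplus(x) respectively.
      softplus = lambda x: np.logaddexp(0.0, x)
      true_xent = softplus(-true_logits)
      sampled_xent = softplus(sampled_logits)
      # Sum of the true and noise contributions, averaged over the batch.
      return (true_xent.sum() + sampled_xent.sum()) / example_emb.shape[0]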
diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc
new file mode 100644
index 0000000000..f68139fc91
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_kernels.cc
@@ -0,0 +1,287 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+class SkipgramOp : public OpKernel {
+ public:
+ explicit SkipgramOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), rng_(&philox_) {
+ string filename;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
+ OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));
+
+ mutex_lock l(mu_);
+ example_pos_ = corpus_size_;
+ label_pos_ = corpus_size_;
+ label_limit_ = corpus_size_;
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor words_per_epoch(DT_INT64, TensorShape({}));
+ Tensor current_epoch(DT_INT32, TensorShape({}));
+ Tensor total_words_processed(DT_INT64, TensorShape({}));
+ Tensor examples(DT_INT32, TensorShape({batch_size_}));
+ auto Texamples = examples.flat<int32>();
+ Tensor labels(DT_INT32, TensorShape({batch_size_}));
+ auto Tlabels = labels.flat<int32>();
+ {
+ mutex_lock l(mu_);
+ for (int i = 0; i < batch_size_; ++i) {
+ NextExample(&Texamples(i), &Tlabels(i));
+ }
+ words_per_epoch.scalar<int64>()() = corpus_size_;
+ current_epoch.scalar<int32>()() = current_epoch_;
+ total_words_processed.scalar<int64>()() = total_words_processed_;
+ }
+ ctx->set_output(0, word_);
+ ctx->set_output(1, freq_);
+ ctx->set_output(2, words_per_epoch);
+ ctx->set_output(3, current_epoch);
+ ctx->set_output(4, total_words_processed);
+ ctx->set_output(5, examples);
+ ctx->set_output(6, labels);
+ }
+
+ private:
+ int32 batch_size_ = 0;
+ int32 window_size_ = 5;
+ float subsample_ = 1e-3;
+ int min_count_ = 5;
+ int32 vocab_size_ = 0;
+ Tensor word_;
+ Tensor freq_;
+ int32 corpus_size_ = 0;
+ std::vector<int32> corpus_;
+
+ mutex mu_;
+ random::PhiloxRandom philox_ GUARDED_BY(mu_);
+ random::SimplePhilox rng_ GUARDED_BY(mu_);
+ int32 current_epoch_ GUARDED_BY(mu_) = -1;
+ int64 total_words_processed_ GUARDED_BY(mu_) = 0;
+ int32 example_pos_ GUARDED_BY(mu_);
+ int32 label_pos_ GUARDED_BY(mu_);
+ int32 label_limit_ GUARDED_BY(mu_);
+
+ // {example_pos_, label_pos_} is the cursor for the next example.
+  // example_pos_ wraps around at the end of corpus_. For each
+  // example, we randomly generate [label_pos_, label_limit_) for
+ // labels.
+ void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ while (true) {
+ if (label_pos_ >= label_limit_) {
+ if (example_pos_ + 1 >= corpus_size_) {
+ ++current_epoch_;
+ example_pos_ = 0;
+ } else {
+ ++example_pos_;
+ }
+ ++total_words_processed_;
+ int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
+ if (subsample_ > 0) {
+ // See Eq. 5 in http://arxiv.org/abs/1310.4546
+ float keep_prob =
+ (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
+ (subsample_ * corpus_size_) / word_freq;
+ if (rng_.RandFloat() > keep_prob) continue;
+ }
+ const int32 skip = 1 + rng_.Uniform(window_size_);
+ label_pos_ = std::max<int32>(0, example_pos_ - skip);
+ label_limit_ = std::min<int32>(corpus_size_, example_pos_ + skip + 1);
+ }
+ if (example_pos_ != label_pos_) {
+ break;
+ }
+ ++label_pos_;
+ }
+ *example = corpus_[example_pos_];
+ *label = corpus_[label_pos_++];
+ }
+
+ Status Init(Env* env, const string& filename) {
+ string data;
+ TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
+ RE2 kWord("\\s*(\\S+)");
+ auto input = ToRegexpStringPiece(data);
+ string w;
+ corpus_size_ = 0;
+ std::unordered_map<string, int32> word_freq;
+ while (RE2::Consume(&input, kWord, &w)) {
+ ++(word_freq[w]);
+ ++corpus_size_;
+ }
+ if (corpus_size_ < window_size_ * 10) {
+ return errors::InvalidArgument("The text file ", filename,
+ " contains too little data: ",
+ corpus_size_, " words");
+ }
+ typedef std::pair<string, int32> WordFreq;
+ std::vector<WordFreq> ordered;
+ for (const auto& p : word_freq) {
+ if (p.second >= min_count_) ordered.push_back(p);
+ }
+ LOG(INFO) << "Data file: " << filename << " contains " << data.size()
+ << " bytes, " << corpus_size_ << " words, " << word_freq.size()
+ << " unique words, " << ordered.size()
+ << " unique frequent words.";
+ word_freq.clear();
+ std::sort(ordered.begin(), ordered.end(),
+ [](const WordFreq& x, const WordFreq& y) {
+ return x.second > y.second;
+ });
+ vocab_size_ = static_cast<int32>(1 + ordered.size());
+ Tensor word(DT_STRING, TensorShape({vocab_size_}));
+ Tensor freq(DT_INT32, TensorShape({vocab_size_}));
+ word.flat<string>()(0) = "UNK";
+ static const int32 kUnkId = 0;
+ std::unordered_map<string, int32> word_id;
+ int64 total_counted = 0;
+ for (std::size_t i = 0; i < ordered.size(); ++i) {
+ const auto& w = ordered[i].first;
+ auto id = i + 1;
+ word.flat<string>()(id) = w;
+ auto word_count = ordered[i].second;
+ freq.flat<int32>()(id) = word_count;
+ total_counted += word_count;
+ word_id[w] = id;
+ }
+ freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
+ word_ = word;
+ freq_ = freq;
+ corpus_.reserve(corpus_size_);
+ input = ToRegexpStringPiece(data);
+ while (RE2::Consume(&input, kWord, &w)) {
+ corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
+ }
+ return Status::OK();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);
+
+class NegTrainOp : public OpKernel {
+ public:
+ explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ base_.Init(0, 0);
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));
+
+ std::vector<int32> vocab_count;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));
+
+ std::vector<float> vocab_weights;
+ vocab_weights.reserve(vocab_count.size());
+ for (const auto& f : vocab_count) {
+ float r = std::pow(static_cast<float>(f), 0.75f);
+ vocab_weights.push_back(r);
+ }
+ sampler_ = new random::DistributionSampler(vocab_weights);
+ }
+
+ ~NegTrainOp() { delete sampler_; }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor w_in = ctx->mutable_input(0, false);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
+ errors::InvalidArgument("Must be a matrix"));
+ Tensor w_out = ctx->mutable_input(1, false);
+ OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
+ errors::InvalidArgument("w_in.shape == w_out.shape"));
+ const Tensor& examples = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
+ errors::InvalidArgument("Must be a vector"));
+ const Tensor& labels = ctx->input(3);
+ OP_REQUIRES(ctx, examples.shape() == labels.shape(),
+ errors::InvalidArgument("examples.shape == labels.shape"));
+ const Tensor& learning_rate = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
+ errors::InvalidArgument("Must be a scalar"));
+
+ auto Tw_in = w_in.matrix<float>();
+ auto Tw_out = w_out.matrix<float>();
+ auto Texamples = examples.flat<int32>();
+ auto Tlabels = labels.flat<int32>();
+ auto lr = learning_rate.scalar<float>()();
+ const int64 vocab_size = w_in.dim_size(0);
+ const int64 dims = w_in.dim_size(1);
+ const int64 batch_size = examples.dim_size(0);
+ OP_REQUIRES(ctx, vocab_size == sampler_->num(),
+ errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
+ " vs. ", sampler_->num()));
+
+ // Gradient accumulator for v_in.
+ Tensor buf(DT_FLOAT, TensorShape({dims}));
+ auto Tbuf = buf.flat<float>();
+
+ // Scalar buffer to hold sigmoid(+/- dot).
+ Tensor g_buf(DT_FLOAT, TensorShape({}));
+ auto g = g_buf.scalar<float>();
+
+ // The following loop needs 2 random 32-bit values per negative
+ // sample. We reserve 8 values per sample just in case the
+ // underlying implementation changes.
+ auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
+ random::SimplePhilox srnd(&rnd);
+
+ for (int64 i = 0; i < batch_size; ++i) {
+ const int32 example = Texamples(i);
+ DCHECK(0 <= example && example < vocab_size) << example;
+ const int32 label = Tlabels(i);
+ DCHECK(0 <= label && label < vocab_size) << label;
+ auto v_in = Tw_in.chip<0>(example);
+
+ // Positive: example predicts label.
+ // forward: x = v_in' * v_out
+ // l = log(sigmoid(x))
+ // backward: dl/dx = g = sigmoid(-x)
+ // dl/d(v_in) = g * v_out'
+ // dl/d(v_out) = v_in' * g
+ {
+ auto v_out = Tw_out.chip<0>(label);
+ auto dot = (v_in * v_out).sum();
+ g = (dot.exp() + 1.f).inverse();
+ Tbuf = v_out * (g() * lr);
+ v_out += v_in * (g() * lr);
+ }
+
+ // Negative samples:
+ // forward: x = v_in' * v_sample
+ // l = log(sigmoid(-x))
+ // backward: dl/dx = g = -sigmoid(x)
+ // dl/d(v_in) = g * v_out'
+ // dl/d(v_out) = v_in' * g
+ for (int j = 0; j < num_samples_; ++j) {
+ const int sample = sampler_->Sample(&srnd);
+ if (sample == label) continue; // Skip.
+ auto v_sample = Tw_out.chip<0>(sample);
+ auto dot = (v_in * v_sample).sum();
+ g = -((-dot).exp() + 1.f).inverse();
+ Tbuf += v_sample * (g() * lr);
+ v_sample += v_in * (g() * lr);
+ }
+
+ // Applies the gradient on v_in.
+ v_in += Tbuf;
+ }
+ }
+
+ private:
+ int32 num_samples_ = 0;
+ random::DistributionSampler* sampler_ = nullptr;
+ GuardedPhiloxRandom base_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);
+
+} // end namespace tensorflow
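Two pieces of the kernel code above are easy to mis-read: the subsampling keep-probability in SkipgramOp::NextExample and the per-example SGD update in NegTrainOp::Compute. A minimal numpy sketch of both (not part of this commit; names are illustrative, and the negative ids are assumed to be drawn by the caller from the unigram^0.75 distribution):

    import numpy as np

    def keep_prob(word_freq, corpus_size, subsample=1e-3):
      # Eq. 5 of http://arxiv.org/abs/1310.4546, as used by SkipgramOp.
      t = subsample * corpus_size
      return (np.sqrt(word_freq / t) + 1.0) * t / word_freq

    def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

    def neg_train_step(w_in, w_out, example, label, sampled_ids, lr):
      """w_in, w_out: [vocab_size, dim] arrays, updated in place."""
      v_in = w_in[example]
      buf = np.zeros_like(v_in)        # gradient accumulator for v_in
      # Positive pair: push sigmoid(v_in . v_out) towards 1.
      v_out = w_out[label]
      g = sigmoid(-v_in.dot(v_out))
      buf += g * lr * v_out
      w_out[label] += g * lr * v_in
      # Negative samples: push sigmoid(v_in . v_sample) towards 0.
      for sample in sampled_ids:
        if sample == label:
          continue                     # skip accidental true labels, as the kernel does
        v_sample = w_out[sample]
        g = -sigmoid(v_in.dot(v_sample))
        buf += g * lr * v_sample
        w_out[sample] += g * lr * v_in
      # Apply the accumulated gradient to the input embedding.
      w_in[example] += buf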
diff --git a/tensorflow/models/embedding/word2vec_ops.cc b/tensorflow/models/embedding/word2vec_ops.cc
new file mode 100644
index 0000000000..abe03baaf4
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_ops.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Skipgram")
+ .Output("vocab_word: string")
+ .Output("vocab_freq: int32")
+ .Output("words_per_epoch: int64")
+ .Output("current_epoch: int32")
+ .Output("total_words_processed: int64")
+ .Output("examples: int32")
+ .Output("labels: int32")
+ .Attr("filename: string")
+ .Attr("batch_size: int")
+ .Attr("window_size: int = 5")
+ .Attr("min_count: int = 5")
+ .Attr("subsample: float = 1e-3")
+ .Doc(R"doc(
+Parses a text file and creates a batch of examples.
+
+vocab_word: A vector of words in the corpus.
+vocab_freq: Frequencies of words. Sorted in non-ascending order.
+words_per_epoch: Number of words per epoch in the data file.
+current_epoch: The current epoch number.
+total_words_processed: The total number of words processed so far.
+examples: A vector of word ids.
+labels: A vector of word ids.
+filename: The corpus's text file name.
+batch_size: The size of the produced batch.
+window_size: The number of words to predict to the left and right of the target.
+min_count: The minimum number of word occurrences for it to be included in the
+ vocabulary.
+subsample: Threshold for word occurrence. Words that appear with higher
+ frequency will be randomly down-sampled. Set to 0 to disable.
+)doc");
+
+REGISTER_OP("NegTrain")
+ .Input("w_in: Ref(float)")
+ .Input("w_out: Ref(float)")
+ .Input("examples: int32")
+ .Input("labels: int32")
+ .Input("lr: float")
+ .Attr("vocab_count: list(int)")
+ .Attr("num_negative_samples: int")
+ .Doc(R"doc(
+Training via negative sampling.
+
+w_in: input word embedding.
+w_out: output word embedding.
+examples: A vector of word ids.
+labels: A vector of word ids.
+vocab_count: Count of words in the vocabulary.
+num_negative_samples: Number of negative samples per example.
+)doc");
+
+} // end namespace tensorflow
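The vocab_count attr of NegTrain feeds the DistributionSampler in the kernel: counts are raised to the 0.75 power, the same distortion passed to fixed_unigram_candidate_sampler in word2vec.py, so frequent words are sampled less than proportionally. A minimal numpy sketch of the resulting sampling distribution (not part of this commit):

    import numpy as np

    def negative_sampling_probs(vocab_counts, distortion=0.75):
      # Unigram counts raised to the 0.75 power, then normalized.
      weights = np.asarray(vocab_counts, dtype=np.float64) ** distortion
      return weights / weights.sum()

    # E.g. a word seen 8x as often as another is only about 4.8x as likely
    # to be drawn as a negative sample.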
diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py
new file mode 100644
index 0000000000..23e7645a0b
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec_optimized.py
@@ -0,0 +1,405 @@
+"""Multi-threaded word2vec unbatched skip-gram model.
+
+Trains the model described in:
+(Mikolov et al.) Efficient Estimation of Word Representations in Vector Space
+ICLR 2013.
+http://arxiv.org/abs/1301.3781
+This model does true SGD (i.e. no minibatching). To do this efficiently, custom
+ops are used to sequentially process data within a 'batch'.
+
+The key ops used are:
+* skipgram custom op that does input processing.
+* neg_train custom op that efficiently calculates and applies the gradient using
+ true SGD.
+"""
+
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.embedding import gen_word2vec as word2vec
+
+flags = tf.app.flags
+
+flags.DEFINE_string("save_path", None, "Directory to write the model.")
+flags.DEFINE_string(
+ "train_data", None,
+ "Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
+flags.DEFINE_string(
+ "eval_data", None, "Analogy questions. "
+ "https://word2vec.googlecode.com/svn/trunk/questions-words.txt.")
+flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
+flags.DEFINE_integer(
+ "epochs_to_train", 15,
+ "Number of epochs to train. Each epoch processes the training data once "
+ "completely.")
+flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
+flags.DEFINE_integer("num_neg_samples", 25,
+ "Negative samples per training example.")
+flags.DEFINE_integer("batch_size", 500,
+                     "Number of training examples each step processes "
+                     "(no minibatching).")
+flags.DEFINE_integer("concurrent_steps", 12,
+ "The number of concurrent training steps.")
+flags.DEFINE_integer("window_size", 5,
+ "The number of words to predict to the left and right "
+ "of the target word.")
+flags.DEFINE_integer("min_count", 5,
+ "The minimum number of word occurrences for it to be "
+ "included in the vocabulary.")
+flags.DEFINE_float("subsample", 1e-3,
+ "Subsample threshold for word occurrence. Words that appear "
+ "with higher frequency will be randomly down-sampled. Set "
+ "to 0 to disable.")
+flags.DEFINE_boolean(
+    "interactive", False,
+    "If true, enters an IPython interactive session to play with the trained "
+    "model. E.g., try model.analogy('france', 'paris', 'russia') and "
+    "model.nearby(['proton', 'elephant', 'maxwell']).")
+
+FLAGS = flags.FLAGS
+
+
+class Options(object):
+ """Options used by our word2vec model."""
+
+ def __init__(self):
+ # Model options.
+
+ # Embedding dimension.
+ self.emb_dim = FLAGS.embedding_size
+
+ # Training options.
+
+ # The training text file.
+ self.train_data = FLAGS.train_data
+
+ # Number of negative samples per example.
+ self.num_samples = FLAGS.num_neg_samples
+
+ # The initial learning rate.
+ self.learning_rate = FLAGS.learning_rate
+
+ # Number of epochs to train. After these many epochs, the learning
+ # rate decays linearly to zero and the training stops.
+ self.epochs_to_train = FLAGS.epochs_to_train
+
+ # Concurrent training steps.
+ self.concurrent_steps = FLAGS.concurrent_steps
+
+ # Number of examples for one training step.
+ self.batch_size = FLAGS.batch_size
+
+ # The number of words to predict to the left and right of the target word.
+ self.window_size = FLAGS.window_size
+
+ # The minimum number of word occurrences for it to be included in the
+ # vocabulary.
+ self.min_count = FLAGS.min_count
+
+ # Subsampling threshold for word occurrence.
+ self.subsample = FLAGS.subsample
+
+ # Where to write out summaries.
+ self.save_path = FLAGS.save_path
+
+ # Eval options.
+
+ # The text file for eval.
+ self.eval_data = FLAGS.eval_data
+
+
+class Word2Vec(object):
+ """Word2Vec model (Skipgram)."""
+
+ def __init__(self, options, session):
+ self._options = options
+ self._session = session
+ self._word2id = {}
+ self._id2word = []
+ self.build_graph()
+ self.build_eval_graph()
+ self.save_vocab()
+ self._read_analogies()
+
+ def _read_analogies(self):
+ """Reads through the analogy question file.
+
+ Returns:
+ questions: a [n, 4] numpy array containing the analogy question's
+ word ids.
+ questions_skipped: questions skipped due to unknown words.
+ """
+ questions = []
+ questions_skipped = 0
+ with open(self._options.eval_data) as analogy_f:
+ for line in analogy_f:
+ if line.startswith(":"): # Skip comments.
+ continue
+ words = line.strip().lower().split(" ")
+ ids = [self._word2id.get(w.strip()) for w in words]
+ if None in ids or len(ids) != 4:
+ questions_skipped += 1
+ else:
+ questions.append(np.array(ids))
+ print "Eval analogy file: ", self._options.eval_data
+ print "Questions: ", len(questions)
+ print "Skipped: ", questions_skipped
+ self._analogy_questions = np.array(questions, dtype=np.int32)
+
+ def build_graph(self):
+ """Build the model graph."""
+ opts = self._options
+
+ # The training data. A text file.
+ (words, counts, words_per_epoch, current_epoch, total_words_processed,
+ examples, labels) = word2vec.skipgram(filename=opts.train_data,
+ batch_size=opts.batch_size,
+ window_size=opts.window_size,
+ min_count=opts.min_count,
+ subsample=opts.subsample)
+ (opts.vocab_words, opts.vocab_counts,
+ opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
+ opts.vocab_size = len(opts.vocab_words)
+ print "Data file: ", opts.train_data
+ print "Vocab size: ", opts.vocab_size - 1, " + UNK"
+ print "Words per epoch: ", opts.words_per_epoch
+
+ self._id2word = opts.vocab_words
+ for i, w in enumerate(self._id2word):
+ self._word2id[w] = i
+
+ # Declare all variables we need.
+ # Input words embedding: [vocab_size, emb_dim]
+ w_in = tf.Variable(
+ tf.random_uniform(
+ [opts.vocab_size,
+ opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim),
+ name="w_in")
+
+    # Output word embedding: [vocab_size, emb_dim].
+ w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out")
+
+ # Global step: []
+ global_step = tf.Variable(0, name="global_step")
+
+ # Linear learning rate decay.
+ words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
+ lr = opts.learning_rate * tf.maximum(
+ 0.0001,
+ 1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train)
+
+ # Training nodes.
+ inc = global_step.assign_add(1)
+ with tf.control_dependencies([inc]):
+ train = word2vec.neg_train(w_in,
+ w_out,
+ examples,
+ labels,
+ lr,
+ vocab_count=opts.vocab_counts.tolist(),
+ num_negative_samples=opts.num_samples)
+
+ self._w_in = w_in
+ self._examples = examples
+ self._labels = labels
+ self._lr = lr
+ self._train = train
+ self.step = global_step
+ self._epoch = current_epoch
+ self._words = total_words_processed
+
+ def save_vocab(self):
+ """Save the vocabulary to a file so the model can be reloaded."""
+ opts = self._options
+ with open(opts.save_path + "/vocab.txt", "w") as f:
+ for i in xrange(opts.vocab_size):
+ f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n")
+
+ def build_eval_graph(self):
+ """Build the evaluation graph."""
+ # Eval graph
+ opts = self._options
+
+ # Each analogy task is to predict the 4th word (d) given three
+ # words: a, b, c. E.g., a=italy, b=rome, c=france, we should
+ # predict d=paris.
+
+ # The eval feeds three vectors of word ids for a, b, c, each of
+ # which is of size N, where N is the number of analogies we want to
+ # evaluate in one batch.
+ analogy_a = tf.placeholder(dtype=tf.int32) # [N]
+ analogy_b = tf.placeholder(dtype=tf.int32) # [N]
+ analogy_c = tf.placeholder(dtype=tf.int32) # [N]
+
+ # Normalized word embeddings of shape [vocab_size, emb_dim].
+ nemb = tf.nn.l2_normalize(self._w_in, 1)
+
+ # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
+ # They all have the shape [N, emb_dim]
+ a_emb = tf.gather(nemb, analogy_a) # a's embs
+ b_emb = tf.gather(nemb, analogy_b) # b's embs
+ c_emb = tf.gather(nemb, analogy_c) # c's embs
+
+    # We expect that d's embedding vectors on the unit hyper-sphere are
+ # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
+ target = c_emb + (b_emb - a_emb)
+
+ # Compute cosine distance between each pair of target and vocab.
+ # dist has shape [N, vocab_size].
+ dist = tf.matmul(target, nemb, transpose_b=True)
+
+ # For each question (row in dist), find the top 4 words.
+ _, pred_idx = tf.nn.top_k(dist, 4)
+
+ # Nodes for computing neighbors for a given word according to
+ # their cosine distance.
+ nearby_word = tf.placeholder(dtype=tf.int32) # word id
+ nearby_emb = tf.gather(nemb, nearby_word)
+ nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
+ nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
+ min(1000, opts.vocab_size))
+
+    # Nodes in the constructed graph which are used by training and
+ # evaluation to run/feed/fetch.
+ self._analogy_a = analogy_a
+ self._analogy_b = analogy_b
+ self._analogy_c = analogy_c
+ self._analogy_pred_idx = pred_idx
+ self._nearby_word = nearby_word
+ self._nearby_val = nearby_val
+ self._nearby_idx = nearby_idx
+
+ # Properly initialize all variables.
+ tf.initialize_all_variables().run()
+
+ self.saver = tf.train.Saver()
+
+ def _train_thread_body(self):
+ initial_epoch, = self._session.run([self._epoch])
+ while True:
+ _, epoch = self._session.run([self._train, self._epoch])
+ if epoch != initial_epoch:
+ break
+
+ def train(self):
+ """Train the model."""
+ opts = self._options
+
+ initial_epoch, initial_words = self._session.run([self._epoch, self._words])
+
+ workers = []
+ for _ in xrange(opts.concurrent_steps):
+ t = threading.Thread(target=self._train_thread_body)
+ t.start()
+ workers.append(t)
+
+ last_words, last_time = initial_words, time.time()
+ while True:
+      time.sleep(5)  # Report progress once in a while.
+ (epoch, step, words,
+ lr) = self._session.run([self._epoch, self.step, self._words, self._lr])
+ now = time.time()
+ last_words, last_time, rate = words, now, (words - last_words) / (
+ now - last_time)
+ print "Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
+ lr, rate),
+ sys.stdout.flush()
+ if epoch != initial_epoch:
+ break
+
+ for t in workers:
+ t.join()
+
+ def _predict(self, analogy):
+ """Predict the top 4 answers for analogy questions."""
+ idx, = self._session.run([self._analogy_pred_idx], {
+ self._analogy_a: analogy[:, 0],
+ self._analogy_b: analogy[:, 1],
+ self._analogy_c: analogy[:, 2]
+ })
+ return idx
+
+  def eval(self):
+    """Evaluate analogy questions and report accuracy."""
+
+ # How many questions we get right at precision@1.
+ correct = 0
+
+ total = self._analogy_questions.shape[0]
+ start = 0
+ while start < total:
+ limit = start + 2500
+ sub = self._analogy_questions[start:limit, :]
+ idx = self._predict(sub)
+ start = limit
+ for question in xrange(sub.shape[0]):
+ for j in xrange(4):
+ if idx[question, j] == sub[question, 3]:
+ # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
+ correct += 1
+ break
+ elif idx[question, j] in sub[question, :3]:
+ # We need to skip words already in the question.
+ continue
+ else:
+ # The correct label is not the precision@1
+ break
+ print
+ print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+ correct * 100.0 / total)
+
+ def analogy(self, w0, w1, w2):
+ """Predict word w3 as in w0:w1 vs w2:w3."""
+ wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
+ idx = self._predict(wid)
+ for c in [self._id2word[i] for i in idx[0, :]]:
+ if c not in [w0, w1, w2]:
+ return c
+ return "unknown"
+
+ def nearby(self, words, num=20):
+ """Prints out nearby words given a list of words."""
+ ids = np.array([self._word2id.get(x, 0) for x in words])
+ vals, idx = self._session.run(
+ [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
+ for i in xrange(len(words)):
+ print "\n%s\n=====================================" % (words[i])
+ for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
+ print "%-20s %6.4f" % (self._id2word[neighbor], distance)
+
+
+def _start_shell(local_ns=None):
+ # An interactive shell is useful for debugging/development.
+ import IPython
+ user_ns = {}
+ if local_ns:
+ user_ns.update(local_ns)
+ user_ns.update(globals())
+ IPython.start_ipython(argv=[], user_ns=user_ns)
+
+
+def main(_):
+ """Train a word2vec model."""
+ opts = Options()
+ with tf.Graph().as_default(), tf.Session() as session:
+ model = Word2Vec(opts, session)
+ for _ in xrange(opts.epochs_to_train):
+ model.train() # Process one epoch
+ model.eval() # Eval analogies.
+ # Perform a final save.
+ model.saver.save(session, opts.save_path + "model", global_step=model.step)
+ if FLAGS.interactive:
+ # E.g.,
+      # [0]: model.analogy('france', 'paris', 'russia')
+      # [1]: model.nearby(['proton', 'elephant', 'maxwell'])
+ _start_shell(locals())
+
+
+if __name__ == "__main__":
+ tf.app.run()
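Finally, the analogy evaluation built by build_eval_graph() in both scripts reduces to a few dense operations on l2-normalized embeddings. A minimal numpy sketch (not part of this commit; names are illustrative):

    import numpy as np

    def predict_analogy(emb, a_ids, b_ids, c_ids, k=4):
      """emb: [vocab_size, dim]; a_ids, b_ids, c_ids: [N] arrays of word ids."""
      nemb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
      # d is expected to lie near c + (b - a) on the unit hyper-sphere.
      target = nemb[c_ids] + (nemb[b_ids] - nemb[a_ids])   # [N, dim]
      # Cosine similarity of each target against the whole vocabulary.
      dist = target.dot(nemb.T)                            # [N, vocab_size]
      # Indices of the k most similar words per question, best first.
      return np.argsort(-dist, axis=1)[:, :k]

The eval() method then counts a question as correct when the true fourth word appears among these predictions before any word that is not part of the question.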