Diffstat (limited to 'tensorflow/models/embedding/word2vec.py')
-rw-r--r-- | tensorflow/models/embedding/word2vec.py | 503 |
1 file changed, 503 insertions, 0 deletions
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
new file mode 100644
index 0000000000..4ebf3d6f27
--- /dev/null
+++ b/tensorflow/models/embedding/word2vec.py
@@ -0,0 +1,503 @@
+"""Multi-threaded word2vec mini-batched skip-gram model.
+
+Trains the model described in:
+(Mikolov et al.) Efficient Estimation of Word Representations in Vector Space
+ICLR 2013.
+http://arxiv.org/abs/1301.3781
+This model does traditional minibatching.
+
+The key ops used are:
+* placeholder for feeding in tensors for each example.
+* embedding_lookup for fetching rows from the embedding matrix.
+* sigmoid_cross_entropy_with_logits to calculate the loss.
+* GradientDescentOptimizer for optimizing the loss.
+* skipgram custom op that does input processing.
+"""
+
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.embedding import gen_word2vec as word2vec
+
+flags = tf.app.flags
+
+flags.DEFINE_string("save_path", None, "Directory to write the model and "
+                    "training summaries.")
+flags.DEFINE_string("train_data", None, "Training text file. "
+                    "E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
+flags.DEFINE_string(
+    "eval_data", None, "File consisting of analogies of four tokens. "
+    "embedding 2 - embedding 1 + embedding 3 should be close "
+    "to embedding 4. "
+    "E.g. https://word2vec.googlecode.com/svn/trunk/questions-words.txt.")
+flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
+flags.DEFINE_integer(
+    "epochs_to_train", 15,
+    "Number of epochs to train. Each epoch processes the training data once "
+    "completely.")
+flags.DEFINE_float("learning_rate", 0.2, "Initial learning rate.")
+flags.DEFINE_integer("num_neg_samples", 100,
+                     "Negative samples per training example.")
+flags.DEFINE_integer("batch_size", 16,
+                     "Number of training examples processed per step "
+                     "(size of a minibatch).")
+flags.DEFINE_integer("concurrent_steps", 12,
+                     "The number of concurrent training steps.")
+flags.DEFINE_integer("window_size", 5,
+                     "The number of words to predict to the left and right "
+                     "of the target word.")
+flags.DEFINE_integer("min_count", 5,
+                     "The minimum number of word occurrences for it to be "
+                     "included in the vocabulary.")
+flags.DEFINE_float("subsample", 1e-3,
+                   "Subsample threshold for word occurrence. Words that appear "
+                   "with higher frequency will be randomly down-sampled. Set "
+                   "to 0 to disable.")
+flags.DEFINE_boolean(
+    "interactive", False,
+    "If true, enters an IPython interactive session to play with the trained "
+    "model. E.g., try model.analogy('france', 'paris', 'russia') and "
+    "model.nearby(['proton', 'elephant', 'maxwell']).")
+flags.DEFINE_integer("statistics_interval", 5,
+                     "Print statistics every n seconds.")
+flags.DEFINE_integer("summary_interval", 5,
+                     "Save training summary to file every n seconds (rounded "
+                     "up to statistics interval).")
+flags.DEFINE_integer("checkpoint_interval", 600,
+                     "Checkpoint the model (i.e. save the parameters) every n "
+                     "seconds (rounded up to statistics interval).")
+
+FLAGS = flags.FLAGS
+
+
+class Options(object):
+  """Options used by our word2vec model."""
+
+  def __init__(self):
+    # Model options.
+
+    # Embedding dimension.
+    self.emb_dim = FLAGS.embedding_size
+
+    # Training options.
+    # The training text file.
+    self.train_data = FLAGS.train_data
+
+    # Number of negative samples per example.
+    self.num_samples = FLAGS.num_neg_samples
+
+    # The initial learning rate.
+    self.learning_rate = FLAGS.learning_rate
+
+    # Number of epochs to train. After these many epochs, the learning
+    # rate decays linearly to zero and the training stops.
+    self.epochs_to_train = FLAGS.epochs_to_train
+
+    # Concurrent training steps.
+    self.concurrent_steps = FLAGS.concurrent_steps
+
+    # Number of examples for one training step.
+    self.batch_size = FLAGS.batch_size
+
+    # The number of words to predict to the left and right of the target word.
+    self.window_size = FLAGS.window_size
+
+    # The minimum number of word occurrences for it to be included in the
+    # vocabulary.
+    self.min_count = FLAGS.min_count
+
+    # Subsampling threshold for word occurrence.
+    self.subsample = FLAGS.subsample
+
+    # How often to print statistics.
+    self.statistics_interval = FLAGS.statistics_interval
+
+    # How often to write to the summary file (rounds up to the nearest
+    # statistics_interval).
+    self.summary_interval = FLAGS.summary_interval
+
+    # How often to write checkpoints (rounds up to the nearest statistics
+    # interval).
+    self.checkpoint_interval = FLAGS.checkpoint_interval
+
+    # Where to write out summaries.
+    self.save_path = FLAGS.save_path
+
+    # Eval options.
+    # The text file for eval.
+    self.eval_data = FLAGS.eval_data
+
+
+class Word2Vec(object):
+  """Word2Vec model (Skipgram)."""
+
+  def __init__(self, options, session):
+    self._options = options
+    self._session = session
+    self._word2id = {}
+    self._id2word = []
+    self.build_graph()
+    self.build_eval_graph()
+    self.save_vocab()
+    self._read_analogies()
+
+  def _read_analogies(self):
+    """Reads through the analogy question file.
+
+    Returns:
+      questions: a [n, 4] numpy array containing the analogy question's
+        word ids.
+      questions_skipped: questions skipped due to unknown words.
+    """
+    questions = []
+    questions_skipped = 0
+    with open(self._options.eval_data) as analogy_f:
+      for line in analogy_f:
+        if line.startswith(":"):  # Skip comments.
+          continue
+        words = line.strip().lower().split(" ")
+        ids = [self._word2id.get(w.strip()) for w in words]
+        if None in ids or len(ids) != 4:
+          questions_skipped += 1
+        else:
+          questions.append(np.array(ids))
+    print "Eval analogy file: ", self._options.eval_data
+    print "Questions: ", len(questions)
+    print "Skipped: ", questions_skipped
+    self._analogy_questions = np.array(questions, dtype=np.int32)
+
+  def forward(self, examples, labels):
+    """Build the graph for the forward pass."""
+    opts = self._options
+
+    # Declare all variables we need.
+    # Embedding: [vocab_size, emb_dim]
+    init_width = 0.5 / opts.emb_dim
+    emb = tf.Variable(
+        tf.random_uniform(
+            [opts.vocab_size, opts.emb_dim], -init_width, init_width),
+        name="emb")
+    self._emb = emb
+
+    # Softmax weight: [vocab_size, emb_dim]. Transposed.
+    sm_w_t = tf.Variable(
+        tf.zeros([opts.vocab_size, opts.emb_dim]),
+        name="sm_w_t")
+
+    # Softmax bias: [vocab_size].
+    sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b")
+
+    # Global step: scalar, i.e., shape [].
+    self.global_step = tf.Variable(0, name="global_step")
+
+    # Nodes to compute the NCE loss w/ candidate sampling.
+    labels_matrix = tf.reshape(
+        tf.cast(labels,
+                dtype=tf.int64),
+        [opts.batch_size, 1])
+
+    # Negative sampling.
+    sampled_ids, _, _ = (tf.nn.fixed_unigram_candidate_sampler(
+        true_classes=labels_matrix,
+        num_true=1,
+        num_sampled=opts.num_samples,
+        unique=True,
+        range_max=opts.vocab_size,
+        distortion=0.75,
+        unigrams=opts.vocab_counts.tolist()))
+
+    # Embeddings for examples: [batch_size, emb_dim]
+    example_emb = tf.nn.embedding_lookup(emb, examples)
+
+    # Weights for labels: [batch_size, emb_dim]
+    true_w = tf.nn.embedding_lookup(sm_w_t, labels)
+    # Biases for labels: [batch_size, 1]
+    true_b = tf.nn.embedding_lookup(sm_b, labels)
+
+    # Weights for sampled ids: [num_sampled, emb_dim]
+    sampled_w = tf.nn.embedding_lookup(sm_w_t, sampled_ids)
+    # Biases for sampled ids: [num_sampled, 1]
+    sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids)
+
+    # True logits: [batch_size, 1]
+    true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
+
+    # Sampled logits: [batch_size, num_sampled]
+    # We replicate the sampled noise labels for all examples in the batch
+    # using the matmul.
+    sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples])
+    sampled_logits = tf.matmul(example_emb,
+                               sampled_w,
+                               transpose_b=True) + sampled_b_vec
+    return true_logits, sampled_logits
+
+  def nce_loss(self, true_logits, sampled_logits):
+    """Build the graph for the NCE loss."""
+
+    # cross-entropy(logits, labels)
+    opts = self._options
+    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        true_logits, tf.ones_like(true_logits))
+    sampled_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        sampled_logits, tf.zeros_like(sampled_logits))
+
+    # NCE-loss is the sum of the true and noise (sampled words)
+    # contributions, averaged over the batch.
+    nce_loss_tensor = (tf.reduce_sum(true_xent) +
+                       tf.reduce_sum(sampled_xent)) / opts.batch_size
+    return nce_loss_tensor
+
+  def optimize(self, loss):
+    """Build the graph to optimize the loss function."""
+
+    # Optimizer nodes.
+    # Linear learning rate decay.
+    opts = self._options
+    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
+    lr = opts.learning_rate * tf.maximum(
+        0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
+    self._lr = lr
+    optimizer = tf.train.GradientDescentOptimizer(lr)
+    train = optimizer.minimize(loss,
+                               global_step=self.global_step,
+                               gate_gradients=optimizer.GATE_NONE)
+    self._train = train
+
+  def build_eval_graph(self):
+    """Build the eval graph."""
+    # Eval graph
+
+    # Each analogy task is to predict the 4th word (d) given three
+    # words: a, b, c.  E.g., a=italy, b=rome, c=france, we should
+    # predict d=paris.
+
+    # The eval feeds three vectors of word ids for a, b, c, each of
+    # which is of size N, where N is the number of analogies we want to
+    # evaluate in one batch.
+    analogy_a = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_b = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_c = tf.placeholder(dtype=tf.int32)  # [N]
+
+    # Normalized word embeddings of shape [vocab_size, emb_dim].
+    nemb = tf.nn.l2_normalize(self._emb, 1)
+
+    # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
+    # They all have the shape [N, emb_dim].
+    a_emb = tf.gather(nemb, analogy_a)  # a's embs
+    b_emb = tf.gather(nemb, analogy_b)  # b's embs
+    c_emb = tf.gather(nemb, analogy_c)  # c's embs
+
+    # We expect that d's embedding vector on the unit hyper-sphere is
+    # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
+    target = c_emb + (b_emb - a_emb)
+
+    # Compute cosine distance between each pair of target and vocab.
+    # dist has shape [N, vocab_size].
+    dist = tf.matmul(target, nemb, transpose_b=True)
+
+    # For each question (row in dist), find the top 4 words.
+    _, pred_idx = tf.nn.top_k(dist, 4)
+
+    # Nodes for computing neighbors for a given word according to
+    # their cosine distance.
+    nearby_word = tf.placeholder(dtype=tf.int32)  # word id
+    nearby_emb = tf.gather(nemb, nearby_word)
+    nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
+    nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
+                                         min(1000, self._options.vocab_size))
+
+    # Nodes in the constructed graph which are used by training and
+    # evaluation to run/feed/fetch.
+    self._analogy_a = analogy_a
+    self._analogy_b = analogy_b
+    self._analogy_c = analogy_c
+    self._analogy_pred_idx = pred_idx
+    self._nearby_word = nearby_word
+    self._nearby_val = nearby_val
+    self._nearby_idx = nearby_idx
+
+  def build_graph(self):
+    """Build the graph for the full model."""
+    opts = self._options
+    # The training data. A text file.
+    (words, counts, words_per_epoch, self._epoch, self._words, examples,
+     labels) = word2vec.skipgram(filename=opts.train_data,
+                                 batch_size=opts.batch_size,
+                                 window_size=opts.window_size,
+                                 min_count=opts.min_count,
+                                 subsample=opts.subsample)
+    (opts.vocab_words, opts.vocab_counts,
+     opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
+    opts.vocab_size = len(opts.vocab_words)
+    print "Data file: ", opts.train_data
+    print "Vocab size: ", opts.vocab_size - 1, " + UNK"
+    print "Words per epoch: ", opts.words_per_epoch
+    self._examples = examples
+    self._labels = labels
+    self._id2word = opts.vocab_words
+    for i, w in enumerate(self._id2word):
+      self._word2id[w] = i
+    true_logits, sampled_logits = self.forward(examples, labels)
+    loss = self.nce_loss(true_logits, sampled_logits)
+    tf.scalar_summary("NCE loss", loss)
+    self._loss = loss
+    self.optimize(loss)
+
+    # Properly initialize all variables.
+    tf.initialize_all_variables().run()
+
+    self.saver = tf.train.Saver()
+
+  def save_vocab(self):
+    """Save the vocabulary to a file so the model can be reloaded."""
+    opts = self._options
+    with open(opts.save_path + "/vocab.txt", "w") as f:
+      for i in xrange(opts.vocab_size):
+        f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n")
+
+  def _train_thread_body(self):
+    initial_epoch, = self._session.run([self._epoch])
+    while True:
+      _, epoch = self._session.run([self._train, self._epoch])
+      if epoch != initial_epoch:
+        break
+
+  def train(self):
+    """Train the model."""
+    opts = self._options
+
+    initial_epoch, initial_words = self._session.run([self._epoch, self._words])
+
+    summary_op = tf.merge_all_summaries()
+    summary_writer = tf.train.SummaryWriter(opts.save_path,
+                                            graph_def=self._session.graph_def)
+    workers = []
+    for _ in xrange(opts.concurrent_steps):
+      t = threading.Thread(target=self._train_thread_body)
+      t.start()
+      workers.append(t)
+
+    last_words, last_time, last_summary_time = initial_words, time.time(), 0
+    last_checkpoint_time = 0
+    while True:
+      time.sleep(opts.statistics_interval)  # Report our progress once in a while.
+      (epoch, step, loss, words, lr) = self._session.run(
+          [self._epoch, self.global_step, self._loss, self._words, self._lr])
+      now = time.time()
+      last_words, last_time, rate = words, now, (words - last_words) / (
+          now - last_time)
+      print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
+            (epoch, step, lr, loss, rate)),
+      sys.stdout.flush()
+      if now - last_summary_time > opts.summary_interval:
+        summary_str = self._session.run(summary_op)
+        summary_writer.add_summary(summary_str, step)
+        last_summary_time = now
+      if now - last_checkpoint_time > opts.checkpoint_interval:
+        self.saver.save(self._session,
+                        opts.save_path + "model",
+                        global_step=step)
+        last_checkpoint_time = now
+      if epoch != initial_epoch:
+        break
+
+    for t in workers:
+      t.join()
+
+    return epoch
+
+  def _predict(self, analogy):
+    """Predict the top 4 answers for analogy questions."""
+    idx, = self._session.run([self._analogy_pred_idx], {
+        self._analogy_a: analogy[:, 0],
+        self._analogy_b: analogy[:, 1],
+        self._analogy_c: analogy[:, 2]
+    })
+    return idx
+
+  def eval(self):
+    """Evaluate analogy questions and report accuracy."""
+
+    # How many questions we get right at precision@1.
+    correct = 0
+
+    total = self._analogy_questions.shape[0]
+    start = 0
+    while start < total:
+      limit = start + 2500
+      sub = self._analogy_questions[start:limit, :]
+      idx = self._predict(sub)
+      start = limit
+      for question in xrange(sub.shape[0]):
+        for j in xrange(4):
+          if idx[question, j] == sub[question, 3]:
+            # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
+            correct += 1
+            break
+          elif idx[question, j] in sub[question, :3]:
+            # We need to skip words already in the question.
+            continue
+          else:
+            # The correct label is not the precision@1.
+            break
+    print
+    print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+                                              correct * 100.0 / total)
+
+  def analogy(self, w0, w1, w2):
+    """Predict word w3 as in w0:w1 vs w2:w3."""
+    wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
+    idx = self._predict(wid)
+    for c in [self._id2word[i] for i in idx[0, :]]:
+      if c not in [w0, w1, w2]:
+        return c
+    return "unknown"
+
+  def nearby(self, words, num=20):
+    """Prints out nearby words given a list of words."""
+    ids = np.array([self._word2id.get(x, 0) for x in words])
+    vals, idx = self._session.run(
+        [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
+    for i in xrange(len(words)):
+      print "\n%s\n=====================================" % (words[i])
+      for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
+        print "%-20s %6.4f" % (self._id2word[neighbor], distance)
+
+
+def _start_shell(local_ns=None):
+  # An interactive shell is useful for debugging/development.
+  import IPython
+  user_ns = {}
+  if local_ns:
+    user_ns.update(local_ns)
+  user_ns.update(globals())
+  IPython.start_ipython(argv=[], user_ns=user_ns)
+
+
+def main(_):
+  """Train a word2vec model."""
+  opts = Options()
+  with tf.Graph().as_default(), tf.Session() as session:
+    model = Word2Vec(opts, session)
+    for _ in xrange(opts.epochs_to_train):
+      model.train()  # Process one epoch.
+      model.eval()  # Eval analogies.
+    # Perform a final save.
+    model.saver.save(session,
+                     opts.save_path + "model",
+                     global_step=model.global_step)
+    if FLAGS.interactive:
+      # E.g.,
+      # [0]: model.analogy('france', 'paris', 'russia')
+      # [1]: model.nearby(['proton', 'elephant', 'maxwell'])
+      _start_shell(locals())
+
+
+if __name__ == "__main__":
+  tf.app.run()
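
For reference, the loss assembled in nce_loss() above is plain sigmoid cross-entropy: the logit of the true target word is scored against label 1, the logits of the sampled noise words against label 0, and the sums are averaged over the batch. A minimal NumPy sketch with toy logit values (illustrative only, not part of this commit):

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

batch_size = 3
true_logits = np.array([2.0, 0.5, -1.0])             # one true-word logit per example
sampled_logits = np.array([[0.3, -0.2, 0.1, 0.0],    # [batch_size, num_samples] logits
                           [0.0, 0.5, -0.4, 0.2],    # for the shared noise words
                           [-0.1, 0.1, 0.3, -0.3]])

true_xent = -np.log(sigmoid(true_logits))               # cross-entropy against label 1
sampled_xent = -np.log(1.0 - sigmoid(sampled_logits))   # cross-entropy against label 0
nce_loss = (true_xent.sum() + sampled_xent.sum()) / batch_size
print("toy NCE loss: %.4f" % nce_loss)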
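
The schedule built in optimize() decays the learning rate linearly in the number of words processed, with a floor of 0.0001 times the initial rate. A small sketch using a toy total (in the real graph, words_to_train is words_per_epoch * epochs_to_train):

initial_lr = 0.2              # default of the learning_rate flag
words_to_train = 1.0e6        # toy total; illustrative only
for words_processed in (0, 250000, 500000, 999999):
  lr = initial_lr * max(0.0001, 1.0 - words_processed / words_to_train)
  print("after %6d words: lr = %.5f" % (words_processed, lr))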
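
build_eval_graph() answers an analogy a:b :: c:? by L2-normalizing the embeddings, forming the target vector c_emb + (b_emb - a_emb), scoring every vocabulary word with a single matmul, and keeping the top 4 candidates. The same computation in NumPy on toy embeddings (illustrative only):

import numpy as np

np.random.seed(0)
emb = np.random.randn(10, 8)                  # toy [vocab_size, emb_dim] embeddings
nemb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

a, b, c = 1, 2, 3                             # word ids for the analogy a:b :: c:?
target = nemb[c] + (nemb[b] - nemb[a])        # expected direction of the answer d
dist = nemb.dot(target)                       # cosine-style score for every word
top4 = np.argsort(-dist)[:4]                  # top 4 candidates, as in _predict()
print("top-4 candidate ids: %s" % top4)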