Diffstat (limited to 'tensorflow/examples/tutorials/word2vec/word2vec_basic.py')
-rw-r--r--  tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 628c6e2741..c717693a56 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -31,6 +31,7 @@ import tensorflow as tf
 # Step 1: Download the data.
 url = 'http://mattmahoney.net/dc/'
 
+
 def maybe_download(filename, expected_bytes):
   """Download a file if not present, and make sure it's the right size."""
   if not os.path.exists(filename):
@@ -60,6 +61,7 @@ print('Data size', len(words))
 # Step 2: Build the dictionary and replace rare words with UNK token.
 vocabulary_size = 50000
 
+
 def build_dataset(words):
   count = [['UNK', -1]]
   count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
@@ -94,14 +96,14 @@ def generate_batch(batch_size, num_skips, skip_window):
   assert num_skips <= 2 * skip_window
   batch = np.ndarray(shape=(batch_size), dtype=np.int32)
   labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
-  span = 2 * skip_window + 1 # [ skip_window target skip_window ]
+  span = 2 * skip_window + 1  # [ skip_window target skip_window ]
   buffer = collections.deque(maxlen=span)
   for _ in range(span):
     buffer.append(data[data_index])
     data_index = (data_index + 1) % len(data)
   for i in range(batch_size // num_skips):
     target = skip_window  # target label at the center of the buffer
-    targets_to_avoid = [ skip_window ]
+    targets_to_avoid = [skip_window]
     for j in range(num_skips):
       while target in targets_to_avoid:
         target = random.randint(0, span - 1)
@@ -115,7 +117,7 @@ def generate_batch(batch_size, num_skips, skip_window):
 batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
 for i in range(8):
   print(batch[i], reverse_dictionary[batch[i]],
-      '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
 
 # Step 4: Build and train a skip-gram model.
 
@@ -187,7 +189,7 @@ with tf.Session(graph=graph) as session:
   for step in xrange(num_steps):
     batch_inputs, batch_labels = generate_batch(
         batch_size, num_skips, skip_window)
-    feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
+    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
 
     # We perform one update step by evaluating the optimizer op (including it
     # in the list of returned values for session.run()
@@ -206,8 +208,8 @@ with tf.Session(graph=graph) as session:
      sim = similarity.eval()
      for i in xrange(valid_size):
        valid_word = reverse_dictionary[valid_examples[i]]
-        top_k = 8 # number of nearest neighbors
-        nearest = (-sim[i, :]).argsort()[1:top_k+1]
+        top_k = 8  # number of nearest neighbors
+        nearest = (-sim[i, :]).argsort()[1:top_k + 1]
        log_str = "Nearest to %s:" % valid_word
        for k in xrange(top_k):
          close_word = reverse_dictionary[nearest[k]]
@@ -217,11 +219,12 @@ with tf.Session(graph=graph) as session:
 
 # Step 6: Visualize the embeddings.
 
+
 def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
   assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
-  plt.figure(figsize=(18, 18))  #in inches
+  plt.figure(figsize=(18, 18))  # in inches
   for i, label in enumerate(labels):
-    x, y = low_dim_embs[i,:]
+    x, y = low_dim_embs[i, :]
     plt.scatter(x, y)
     plt.annotate(label,
                  xy=(x, y),
@@ -238,7 +241,7 @@ try:
 
   tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
   plot_only = 500
-  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only,:])
+  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
   labels = [reverse_dictionary[i] for i in xrange(plot_only)]
   plot_with_labels(low_dim_embs, labels)