Diffstat (limited to 'tensorflow/examples/tutorials/word2vec/word2vec_basic.py')
-rw-r--r--  tensorflow/examples/tutorials/word2vec/word2vec_basic.py  21
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 628c6e2741..c717693a56 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -31,6 +31,7 @@ import tensorflow as tf
 # Step 1: Download the data.
 url = 'http://mattmahoney.net/dc/'
+
 def maybe_download(filename, expected_bytes):
   """Download a file if not present, and make sure it's the right size."""
   if not os.path.exists(filename):
@@ -60,6 +61,7 @@ print('Data size', len(words))
 # Step 2: Build the dictionary and replace rare words with UNK token.
 vocabulary_size = 50000
+
 def build_dataset(words):
   count = [['UNK', -1]]
   count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
@@ -94,14 +96,14 @@ def generate_batch(batch_size, num_skips, skip_window):
   assert num_skips <= 2 * skip_window
   batch = np.ndarray(shape=(batch_size), dtype=np.int32)
   labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
-  span = 2 * skip_window + 1 # [ skip_window target skip_window ]
+  span = 2 * skip_window + 1  # [ skip_window target skip_window ]
   buffer = collections.deque(maxlen=span)
   for _ in range(span):
     buffer.append(data[data_index])
     data_index = (data_index + 1) % len(data)
   for i in range(batch_size // num_skips):
     target = skip_window # target label at the center of the buffer
-    targets_to_avoid = [ skip_window ]
+    targets_to_avoid = [skip_window]
     for j in range(num_skips):
       while target in targets_to_avoid:
         target = random.randint(0, span - 1)
@@ -115,7 +117,7 @@ def generate_batch(batch_size, num_skips, skip_window):
 batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
 for i in range(8):
   print(batch[i], reverse_dictionary[batch[i]],
-      '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
 # Step 4: Build and train a skip-gram model.
@@ -187,7 +189,7 @@ with tf.Session(graph=graph) as session:
   for step in xrange(num_steps):
     batch_inputs, batch_labels = generate_batch(
         batch_size, num_skips, skip_window)
-    feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
+    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
     # We perform one update step by evaluating the optimizer op (including it
     # in the list of returned values for session.run()
@@ -206,8 +208,8 @@ with tf.Session(graph=graph) as session:
       sim = similarity.eval()
       for i in xrange(valid_size):
         valid_word = reverse_dictionary[valid_examples[i]]
-        top_k = 8 # number of nearest neighbors
-        nearest = (-sim[i, :]).argsort()[1:top_k+1]
+        top_k = 8  # number of nearest neighbors
+        nearest = (-sim[i, :]).argsort()[1:top_k + 1]
         log_str = "Nearest to %s:" % valid_word
         for k in xrange(top_k):
           close_word = reverse_dictionary[nearest[k]]
@@ -217,11 +219,12 @@ with tf.Session(graph=graph) as session:
 # Step 6: Visualize the embeddings.
+
 def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
   assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
-  plt.figure(figsize=(18, 18)) #in inches
+  plt.figure(figsize=(18, 18))  # in inches
   for i, label in enumerate(labels):
-    x, y = low_dim_embs[i,:]
+    x, y = low_dim_embs[i, :]
     plt.scatter(x, y)
     plt.annotate(label,
                  xy=(x, y),
@@ -238,7 +241,7 @@ try:
   tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
   plot_only = 500
-  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only,:])
+  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
   labels = [reverse_dictionary[i] for i in xrange(plot_only)]
   plot_with_labels(low_dim_embs, labels)