aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/tutorials
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-22 23:08:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-22 23:12:44 -0700
commit1e849fed6032015eb25149b801de0f7be2d87026 (patch)
tree900c810f50e3b5a3dcc8c85e9e9746bb34c792d6 /tensorflow/examples/tutorials
parentb911a9e6be7cec97d736c157986adff73e3a980c (diff)
Downloading text8.zip to temporary directory so that writing rights to the current directory is no longer required.
If a library is missing when displaying the plot - it shows now details, which one is missing. PiperOrigin-RevId: 166161780
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index e67442b14b..6c93617ae5 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -22,6 +22,7 @@ import collections
import math
import os
import random
+from tempfile import gettempdir
import zipfile
import numpy as np
@@ -33,18 +34,22 @@ import tensorflow as tf
url = 'http://mattmahoney.net/dc/'
+# pylint: disable=redefined-outer-name
def maybe_download(filename, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
- if not os.path.exists(filename):
- filename, _ = urllib.request.urlretrieve(url + filename, filename)
- statinfo = os.stat(filename)
+ local_filename = os.path.join(gettempdir(), filename)
+ if not os.path.exists(local_filename):
+ local_filename, _ = urllib.request.urlretrieve(url + filename,
+ local_filename)
+ statinfo = os.stat(local_filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', filename)
else:
print(statinfo.st_size)
- raise Exception(
- 'Failed to verify ' + filename + '. Can you get to it with a browser?')
- return filename
+ raise Exception('Failed to verify ' + local_filename +
+ '. Can you get to it with a browser?')
+ return local_filename
+
filename = maybe_download('text8.zip', 31344016)
@@ -233,7 +238,9 @@ with tf.Session(graph=graph) as session:
# Step 6: Visualize the embeddings.
-def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
+# pylint: disable=missing-docstring
+# Function to draw visualization of distance between embeddings.
+def plot_with_labels(low_dim_embs, labels, filename):
assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings'
plt.figure(figsize=(18, 18)) # in inches
for i, label in enumerate(labels):
@@ -257,7 +264,8 @@ try:
plot_only = 500
low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
labels = [reverse_dictionary[i] for i in xrange(plot_only)]
- plot_with_labels(low_dim_embs, labels)
+ plot_with_labels(low_dim_embs, labels, os.path.join(gettempdir(), 'tsne.png'))
-except ImportError:
+except ImportError as ex:
print('Please install sklearn, matplotlib, and scipy to show embeddings.')
+ print(ex)