From 1e849fed6032015eb25149b801de0f7be2d87026 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 22 Aug 2017 23:08:18 -0700 Subject: Downloading text8.zip to temporary directory so that writing rights to the current directory is no longer required. If a library is missing when displaying the plot - it shows now details, which one is missing. PiperOrigin-RevId: 166161780 --- .../examples/tutorials/word2vec/word2vec_basic.py | 26 ++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) (limited to 'tensorflow/examples/tutorials') diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index e67442b14b..6c93617ae5 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -22,6 +22,7 @@ import collections import math import os import random +from tempfile import gettempdir import zipfile import numpy as np @@ -33,18 +34,22 @@ import tensorflow as tf url = 'http://mattmahoney.net/dc/' +# pylint: disable=redefined-outer-name def maybe_download(filename, expected_bytes): """Download a file if not present, and make sure it's the right size.""" - if not os.path.exists(filename): - filename, _ = urllib.request.urlretrieve(url + filename, filename) - statinfo = os.stat(filename) + local_filename = os.path.join(gettempdir(), filename) + if not os.path.exists(local_filename): + local_filename, _ = urllib.request.urlretrieve(url + filename, + local_filename) + statinfo = os.stat(local_filename) if statinfo.st_size == expected_bytes: print('Found and verified', filename) else: print(statinfo.st_size) - raise Exception( - 'Failed to verify ' + filename + '. Can you get to it with a browser?') - return filename + raise Exception('Failed to verify ' + local_filename + + '. Can you get to it with a browser?') + return local_filename + filename = maybe_download('text8.zip', 31344016) @@ -233,7 +238,9 @@ with tf.Session(graph=graph) as session: # Step 6: Visualize the embeddings. -def plot_with_labels(low_dim_embs, labels, filename='tsne.png'): +# pylint: disable=missing-docstring +# Function to draw visualization of distance between embeddings. +def plot_with_labels(low_dim_embs, labels, filename): assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings' plt.figure(figsize=(18, 18)) # in inches for i, label in enumerate(labels): @@ -257,7 +264,8 @@ try: plot_only = 500 low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :]) labels = [reverse_dictionary[i] for i in xrange(plot_only)] - plot_with_labels(low_dim_embs, labels) + plot_with_labels(low_dim_embs, labels, os.path.join(gettempdir(), 'tsne.png')) -except ImportError: +except ImportError as ex: print('Please install sklearn, matplotlib, and scipy to show embeddings.') + print(ex) -- cgit v1.2.3