author     Martin Wicke <wicke@google.com>                  2017-03-24 11:06:18 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-24 12:26:57 -0700
commit     f2574c273778eeb05a8ef3ba40544ddee98a9e07
tree       b50cabf762d11a82a41d46d27a482e2fc4fc08f4 /tensorflow/examples/tutorials
parent     ec2f8761168c40a76b95220221889b47f82700d9
Fix lint issues after pull.
Change: 151154030
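Concretely, the lint fixes visible in the diff below appear to fall into a few Google Python style buckets: a missing module docstring, double-quoted string literals where single quotes are preferred, a `build_dataset` parameter that shadowed the module-level `vocabulary_size`, and imports placed inside a `try:` block, which need an explicit pylint waiver. A minimal sketch of those conventions (illustrative only, not code from the commit):

```python
"""Modules start with a one-line docstring."""

# Plain string literals use single quotes.
print('Initialized')

try:
  # Imports below the top of the file carry an explicit waiver
  # for the style checker used in the TensorFlow repo.
  # pylint: disable=g-import-not-at-top
  import matplotlib.pyplot as plt
except ImportError:
  plt = None  # visualization is optional
```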
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 34 | +++++++++++++++++++---------------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index f54a7c37a1..13e5717b0d 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Basic word2vec example."""

 from __future__ import absolute_import
 from __future__ import division
@@ -50,21 +51,22 @@ filename = maybe_download('text8.zip', 31344016)

 # Read the data into a list of strings.
 def read_data(filename):
-  """Extract the first file enclosed in a zip file as a list of words"""
+  """Extract the first file enclosed in a zip file as a list of words."""
   with zipfile.ZipFile(filename) as f:
     data = tf.compat.as_str(f.read(f.namelist()[0])).split()
   return data

-words = read_data(filename)
-print('Data size', len(words))
+vocabulary = read_data(filename)
+print('Data size', len(vocabulary))

 # Step 2: Build the dictionary and replace rare words with UNK token.
 vocabulary_size = 50000

-def build_dataset(words, vocabulary_size):
+def build_dataset(words, n_words):
+  """Process raw inputs into a dataset."""
   count = [['UNK', -1]]
-  count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
+  count.extend(collections.Counter(words).most_common(n_words - 1))
   dictionary = dict()
   for word, _ in count:
     dictionary[word] = len(dictionary)
@@ -78,11 +80,12 @@ def build_dataset(words, vocabulary_size):
       unk_count += 1
     data.append(index)
   count[0][1] = unk_count
-  reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
-  return data, count, dictionary, reverse_dictionary
+  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
+  return data, count, dictionary, reversed_dictionary

-data, count, dictionary, reverse_dictionary = build_dataset(words, vocabulary_size)
-del words  # Hint to reduce memory.
+data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
+                                                            vocabulary_size)
+del vocabulary  # Hint to reduce memory.
 print('Most common words (+UNK)', count[:5])
 print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])

@@ -189,7 +192,7 @@ num_steps = 100001
 with tf.Session(graph=graph) as session:
   # We must initialize all variables before we use them.
   init.run()
-  print("Initialized")
+  print('Initialized')

   average_loss = 0
   for step in xrange(num_steps):
@@ -206,7 +209,7 @@ with tf.Session(graph=graph) as session:
       if step > 0:
         average_loss /= 2000
       # The average loss is an estimate of the loss over the last 2000 batches.
-      print("Average loss at step ", step, ": ", average_loss)
+      print('Average loss at step ', step, ': ', average_loss)
       average_loss = 0

     # Note that this is expensive (~20% slowdown if computed every 500 steps)
@@ -216,10 +219,10 @@ with tf.Session(graph=graph) as session:
         valid_word = reverse_dictionary[valid_examples[i]]
         top_k = 8  # number of nearest neighbors
         nearest = (-sim[i, :]).argsort()[1:top_k + 1]
-        log_str = "Nearest to %s:" % valid_word
+        log_str = 'Nearest to %s:' % valid_word
         for k in xrange(top_k):
           close_word = reverse_dictionary[nearest[k]]
-          log_str = "%s %s," % (log_str, close_word)
+          log_str = '%s %s,' % (log_str, close_word)
         print(log_str)
   final_embeddings = normalized_embeddings.eval()

@@ -227,7 +230,7 @@ with tf.Session(graph=graph) as session:


 def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
-  assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
+  assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings'
   plt.figure(figsize=(18, 18))  # in inches
   for i, label in enumerate(labels):
     x, y = low_dim_embs[i, :]
@@ -242,6 +245,7 @@ def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
   plt.savefig(filename)

 try:
+  # pylint: disable=g-import-not-at-top
   from sklearn.manifold import TSNE
   import matplotlib.pyplot as plt

@@ -252,4 +256,4 @@ try:
   plot_with_labels(low_dim_embs, labels)

 except ImportError:
-  print("Please install sklearn, matplotlib, and scipy to visualize embeddings.")
+  print('Please install sklearn, matplotlib, and scipy to show embeddings.')