author    Martin Wicke <wicke@google.com>                  2017-03-24 11:06:18 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-03-24 12:26:57 -0700
commit    f2574c273778eeb05a8ef3ba40544ddee98a9e07 (patch)
tree      b50cabf762d11a82a41d46d27a482e2fc4fc08f4 /tensorflow/examples/tutorials
parent    ec2f8761168c40a76b95220221889b47f82700d9 (diff)
Fix lint issues after pull.
Change: 151154030
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r--  tensorflow/examples/tutorials/word2vec/word2vec_basic.py  34
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index f54a7c37a1..13e5717b0d 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Basic word2vec example."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -50,21 +51,22 @@ filename = maybe_download('text8.zip', 31344016)
 
 # Read the data into a list of strings.
 def read_data(filename):
-  """Extract the first file enclosed in a zip file as a list of words"""
+  """Extract the first file enclosed in a zip file as a list of words."""
   with zipfile.ZipFile(filename) as f:
     data = tf.compat.as_str(f.read(f.namelist()[0])).split()
   return data
 
-words = read_data(filename)
-print('Data size', len(words))
+vocabulary = read_data(filename)
+print('Data size', len(vocabulary))
 
 # Step 2: Build the dictionary and replace rare words with UNK token.
 vocabulary_size = 50000
 
-def build_dataset(words, vocabulary_size):
+def build_dataset(words, n_words):
+  """Process raw inputs into a dataset."""
   count = [['UNK', -1]]
-  count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
+  count.extend(collections.Counter(words).most_common(n_words - 1))
   dictionary = dict()
   for word, _ in count:
     dictionary[word] = len(dictionary)
@@ -78,11 +80,12 @@ def build_dataset(words, vocabulary_size):
       unk_count += 1
     data.append(index)
   count[0][1] = unk_count
-  reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
-  return data, count, dictionary, reverse_dictionary
+  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
+  return data, count, dictionary, reversed_dictionary
 
-data, count, dictionary, reverse_dictionary = build_dataset(words, vocabulary_size)
-del words  # Hint to reduce memory.
+data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
+                                                            vocabulary_size)
+del vocabulary  # Hint to reduce memory.
 print('Most common words (+UNK)', count[:5])
 print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
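For context, a hedged, self-contained sketch of what the renamed build_dataset produces on a toy corpus; the word list and the small n_words value are made-up assumptions for illustration, not anything from this commit:

# Illustrative toy run of the build_dataset logic above; inputs are made up.
import collections

words = ['the', 'quick', 'fox', 'the', 'lazy', 'dog', 'the']
n_words = 4  # keep 'UNK' plus the 3 most common words

count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(n_words - 1))
dictionary = dict()
for word, _ in count:
  dictionary[word] = len(dictionary)
data = list()
unk_count = 0
for word in words:
  if word in dictionary:
    index = dictionary[word]
  else:
    index = 0  # dictionary['UNK']
    unk_count += 1
  data.append(index)
count[0][1] = unk_count
reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

print(data)                 # [1, 2, 3, 1, 0, 0, 1]; 'lazy' and 'dog' map to UNK
print(count)                # [['UNK', 2], ('the', 3), ('quick', 1), ('fox', 1)]
print(reversed_dictionary)  # {0: 'UNK', 1: 'the', 2: 'quick', 3: 'fox'}

Rare words collapse onto id 0 ('UNK'), which is why the rename to n_words is apt: the parameter bounds the dictionary size, not the corpus length.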
@@ -189,7 +192,7 @@ num_steps = 100001
 with tf.Session(graph=graph) as session:
   # We must initialize all variables before we use them.
   init.run()
-  print("Initialized")
+  print('Initialized')
 
   average_loss = 0
   for step in xrange(num_steps):
@@ -206,7 +209,7 @@ with tf.Session(graph=graph) as session:
       if step > 0:
         average_loss /= 2000
       # The average loss is an estimate of the loss over the last 2000 batches.
-      print("Average loss at step ", step, ": ", average_loss)
+      print('Average loss at step ', step, ': ', average_loss)
       average_loss = 0
 
     # Note that this is expensive (~20% slowdown if computed every 500 steps)
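As a standalone illustration of the reporting pattern in this hunk, a minimal sketch with a fake loss stream (the random values are an assumption standing in for the real per-batch training loss):

import random

average_loss = 0
for step in range(10001):
  loss_val = random.random()  # assumption: stand-in for the per-batch loss
  average_loss += loss_val
  if step % 2000 == 0:
    if step > 0:
      average_loss /= 2000  # mean over the most recent 2000 batches
    print('Average loss at step ', step, ': ', average_loss)
    average_loss = 0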
@@ -216,10 +219,10 @@ with tf.Session(graph=graph) as session:
         valid_word = reverse_dictionary[valid_examples[i]]
         top_k = 8  # number of nearest neighbors
         nearest = (-sim[i, :]).argsort()[1:top_k + 1]
-        log_str = "Nearest to %s:" % valid_word
+        log_str = 'Nearest to %s:' % valid_word
         for k in xrange(top_k):
           close_word = reverse_dictionary[nearest[k]]
-          log_str = "%s %s," % (log_str, close_word)
+          log_str = '%s %s,' % (log_str, close_word)
         print(log_str)
   final_embeddings = normalized_embeddings.eval()
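The argsort line above is the heart of the similarity check. A hedged NumPy sketch of the same lookup (the toy embedding matrix is an assumption); rows are normalized to unit length, so a dot product is cosine similarity:

import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((10, 4))  # assumption: 10 words, 4 dims
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

sim = embeddings.dot(embeddings.T)  # cosine similarity matrix
i = 3                               # hypothetical validation word id
top_k = 8
# Ascending argsort of -sim is a descending sort of sim; index 0 is the
# word itself (similarity 1.0), so the slice starts at 1.
nearest = (-sim[i, :]).argsort()[1:top_k + 1]
print(nearest)

In the tutorial, sim comes from a TensorFlow op over normalized_embeddings; the NumPy version only mirrors the indexing.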
@@ -227,7 +230,7 @@ with tf.Session(graph=graph) as session:
 
 def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
-  assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
+  assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings'
   plt.figure(figsize=(18, 18))  # in inches
   for i, label in enumerate(labels):
     x, y = low_dim_embs[i, :]
@@ -242,6 +245,7 @@ def plot_with_labels(low_dim_embs, labels, filename='tsne.png'):
   plt.savefig(filename)
 
 try:
+  # pylint: disable=g-import-not-at-top
   from sklearn.manifold import TSNE
   import matplotlib.pyplot as plt
@@ -252,4 +256,4 @@ try:
   plot_with_labels(low_dim_embs, labels)
 
 except ImportError:
-  print("Please install sklearn, matplotlib, and scipy to visualize embeddings.")
+  print('Please install sklearn, matplotlib, and scipy to show embeddings.')
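For completeness, a hedged sketch of how t-SNE is wired up around plot_with_labels (this diff only shows the tail of the try block); the random embeddings, label names, and Agg backend are assumptions so the snippet runs self-contained and headless:

import numpy as np
from sklearn.manifold import TSNE
import matplotlib
matplotlib.use('Agg')  # assumption: render to a file, no display needed
import matplotlib.pyplot as plt

final_embeddings = np.random.rand(500, 128)  # stand-in for trained embeddings
labels = ['word%d' % i for i in range(100)]  # stand-in for reverse_dictionary lookups

tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
low_dim_embs = tsne.fit_transform(final_embeddings[:len(labels), :])

plt.figure(figsize=(18, 18))  # in inches
for i, label in enumerate(labels):
  x, y = low_dim_embs[i, :]
  plt.scatter(x, y)
  plt.annotate(label, xy=(x, y), xytext=(5, 2),
               textcoords='offset points', ha='right', va='bottom')
plt.savefig('tsne.png')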