aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/embedding/word2vec.py
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 18:37:11 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 18:37:11 -0800
commitcd9e60c1cd8afef6e39b4b73525d64aee33b656b (patch)
treea2b18fc3aab6169b0982bd987725325e68d7bd66 /tensorflow/models/embedding/word2vec.py
parentf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (diff)
TensorFlow: Upstream latest changes to Git.
Changes: - Updates to installation instructions. - Updates to documentation. - Minor modifications and tests for word2vec. Base CL: 107284192
Diffstat (limited to 'tensorflow/models/embedding/word2vec.py')
-rw-r--r--tensorflow/models/embedding/word2vec.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
index 4ebf3d6f27..a31de44a1d 100644
--- a/tensorflow/models/embedding/word2vec.py
+++ b/tensorflow/models/embedding/word2vec.py
@@ -402,7 +402,7 @@ class Word2Vec(object):
if now - last_checkpoint_time > opts.checkpoint_interval:
self.saver.save(self._session,
opts.save_path + "model",
- global_step=step)
+ global_step=step.astype(int))
last_checkpoint_time = now
if epoch != initial_epoch:
break
@@ -482,6 +482,9 @@ def _start_shell(local_ns=None):
def main(_):
"""Train a word2vec model."""
+ if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
+ print "--train_data --eval_data and --save_path must be specified."
+ sys.exit(1)
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
model = Word2Vec(opts, session)