diff options
author | 2015-11-06 18:37:11 -0800 | |
---|---|---|
committer | 2015-11-06 18:37:11 -0800 | |
commit | cd9e60c1cd8afef6e39b4b73525d64aee33b656b (patch) | |
tree | a2b18fc3aab6169b0982bd987725325e68d7bd66 /tensorflow/models/embedding/word2vec.py | |
parent | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (diff) |
TensorFlow: Upstream latest changes to Git.
Changes:
- Updates to installation instructions.
- Updates to documentation.
- Minor modifications and tests for word2vec.
Base CL: 107284192
Diffstat (limited to 'tensorflow/models/embedding/word2vec.py')
-rw-r--r-- | tensorflow/models/embedding/word2vec.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py index 4ebf3d6f27..a31de44a1d 100644 --- a/tensorflow/models/embedding/word2vec.py +++ b/tensorflow/models/embedding/word2vec.py @@ -402,7 +402,7 @@ class Word2Vec(object): if now - last_checkpoint_time > opts.checkpoint_interval: self.saver.save(self._session, opts.save_path + "model", - global_step=step) + global_step=step.astype(int)) last_checkpoint_time = now if epoch != initial_epoch: break @@ -482,6 +482,9 @@ def _start_shell(local_ns=None): def main(_): """Train a word2vec model.""" + if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path: + print "--train_data --eval_data and --save_path must be specified." + sys.exit(1) opts = Options() with tf.Graph().as_default(), tf.Session() as session: model = Word2Vec(opts, session) |