diff options
-rw-r--r-- | tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py | 57 |
1 files changed, 31 insertions, 26 deletions
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index a90048d813..be5d60449d 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -315,32 +315,37 @@ def main(_): have_gpu = tfe.num_gpus() > 0 use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.logdir)): - with tf.device("/device:GPU:0" if have_gpu else None): - # Make learning_rate a Variable so it can be included in the checkpoint - # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") - sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) - model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, - FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, - use_cudnn_rnn) - optimizer = tf.train.GradientDescentOptimizer(learning_rate) - - best_loss = None - for _ in range(FLAGS.epoch): - train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) - eval_loss = evaluate(model, eval_data) - if not best_loss or eval_loss < best_loss: - if FLAGS.logdir: - tfe.Saver(model.trainable_weights + [learning_rate]).save( - os.path.join(FLAGS.logdir, "ckpt")) - best_loss = eval_loss - else: - learning_rate.assign(learning_rate / 4.0) - sys.stderr.write("eval_loss did not reduce in this epoch, " - "changing learning rate to %f for the next epoch\n" % - learning_rate.numpy()) + with tf.device("/device:GPU:0" if have_gpu else None): + # Make learning_rate a Variable so it can be included in the checkpoint + # and we can resume training with the last saved learning_rate. + learning_rate = tfe.Variable(20.0, name="learning_rate") + model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, + FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, + use_cudnn_rnn) + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + checkpoint = tfe.Checkpoint( + learning_rate=learning_rate, model=model, + # GradientDescentOptimizer has no state to checkpoint, but noting it + # here lets us swap in an optimizer that does. + optimizer=optimizer) + # Restore existing variables now (learning_rate), and restore new variables + # on creation if a checkpoint exists. + checkpoint.restore(tf.train.latest_checkpoint(FLAGS.logdir)) + sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) + + best_loss = None + for _ in range(FLAGS.epoch): + train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) + eval_loss = evaluate(model, eval_data) + if not best_loss or eval_loss < best_loss: + if FLAGS.logdir: + checkpoint.save(os.path.join(FLAGS.logdir, "ckpt")) + best_loss = eval_loss + else: + learning_rate.assign(learning_rate / 4.0) + sys.stderr.write("eval_loss did not reduce in this epoch, " + "changing learning rate to %f for the next epoch\n" % + learning_rate.numpy()) if __name__ == "__main__": |