about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py 57
1 files changed, 31 insertions, 26 deletions
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index a90048d813..be5d60449d 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -315,32 +315,37 @@ def main(_):
have_gpu = tfe.num_gpus() > 0
use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(FLAGS.logdir)):
- with tf.device("/device:GPU:0" if have_gpu else None):
- # Make learning_rate a Variable so it can be included in the checkpoint
- # and we can resume training with the last saved learning_rate.
- learning_rate = tfe.Variable(20.0, name="learning_rate")
- sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
- model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
- FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
- use_cudnn_rnn)
- optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-
- best_loss = None
- for _ in range(FLAGS.epoch):
- train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
- eval_loss = evaluate(model, eval_data)
- if not best_loss or eval_loss < best_loss:
- if FLAGS.logdir:
- tfe.Saver(model.trainable_weights + [learning_rate]).save(
- os.path.join(FLAGS.logdir, "ckpt"))
- best_loss = eval_loss
- else:
- learning_rate.assign(learning_rate / 4.0)
- sys.stderr.write("eval_loss did not reduce in this epoch, "
- "changing learning rate to %f for the next epoch\n" %
- learning_rate.numpy())
+ with tf.device("/device:GPU:0" if have_gpu else None):
+ # Make learning_rate a Variable so it can be included in the checkpoint
+ # and we can resume training with the last saved learning_rate.
+ learning_rate = tfe.Variable(20.0, name="learning_rate")
+ model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
+ FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
+ use_cudnn_rnn)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ checkpoint = tfe.Checkpoint(
+ learning_rate=learning_rate, model=model,
+ # GradientDescentOptimizer has no state to checkpoint, but noting it
+ # here lets us swap in an optimizer that does.
+ optimizer=optimizer)
+ # Restore existing variables now (learning_rate), and restore new variables
+ # on creation if a checkpoint exists.
+ checkpoint.restore(tf.train.latest_checkpoint(FLAGS.logdir))
+ sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
+
+ best_loss = None
+ for _ in range(FLAGS.epoch):
+ train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
+ eval_loss = evaluate(model, eval_data)
+ if not best_loss or eval_loss < best_loss:
+ if FLAGS.logdir:
+ checkpoint.save(os.path.join(FLAGS.logdir, "ckpt"))
+ best_loss = eval_loss
+ else:
+ learning_rate.assign(learning_rate / 4.0)
+ sys.stderr.write("eval_loss did not reduce in this epoch, "
+ "changing learning rate to %f for the next epoch\n" %
+ learning_rate.numpy())
if __name__ == "__main__":