diff options
author | Asim Shankar <ashankar@google.com> | 2018-07-19 08:56:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-19 09:02:54 -0700 |
commit | 2509b3a2152c8dda9fff8ed58f414c1316fa5379 (patch) | |
tree | 4660b2dc70d623d8270cff9545edacdcd534a8fd | |
parent | e9e48b963b1ad1274ad8a0ad7d07d7fa990fe6b9 (diff) |
eager guide: s/tfe.Checkpoint/tf.train.Checkpoint/
PiperOrigin-RevId: 205248470
-rw-r--r-- | tensorflow/contrib/eager/python/examples/gan/mnist.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py | 2 | ||||
-rw-r--r-- | tensorflow/docs_src/guide/eager.md | 16 | ||||
-rw-r--r-- | third_party/examples/eager/spinn/spinn.py | 2 |
4 files changed, 12 insertions, 13 deletions
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b33243021b..9a42179299 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -29,7 +29,6 @@ import time import tensorflow as tf -import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data layers = tf.keras.layers @@ -265,7 +264,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, def main(_): (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: + if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0: (device, data_format) = ('/cpu:0', 'channels_last') print('Using device %s, and data format %s.' % (device, data_format)) @@ -291,7 +290,7 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) - checkpoint = tfe.Checkpoint(**model_objects) + checkpoint = tf.train.Checkpoint(**model_objects) # Restore variables on creation if a checkpoint exists. checkpoint.restore(latest_cpkt) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index d64bf5354e..15776c694e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -315,7 +315,7 @@ def main(_): FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, use_cudnn_rnn) optimizer = tf.train.GradientDescentOptimizer(learning_rate) - checkpoint = tfe.Checkpoint( + checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, model=model, # GradientDescentOptimizer has no state to checkpoint, but noting it # here lets us swap in an optimizer that does. diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md index 42ad9652f8..3b54d6d2bb 100644 --- a/tensorflow/docs_src/guide/eager.md +++ b/tensorflow/docs_src/guide/eager.md @@ -504,13 +504,13 @@ with tf.device("gpu:0"): ### Object-based saving -`tfe.Checkpoint` can save and restore `tf.Variable`s to and from +`tf.train.Checkpoint` can save and restore `tf.Variable`s to and from checkpoints: ```py x = tf.Variable(10.) -checkpoint = tfe.Checkpoint(x=x) # save as "x" +checkpoint = tf.train.Checkpoint(x=x) # save as "x" x.assign(2.) # Assign a new value to the variables and save. save_path = checkpoint.save('./ckpt/') @@ -523,18 +523,18 @@ checkpoint.restore(save_path) print(x) # => 2.0 ``` -To save and load models, `tfe.Checkpoint` stores the internal state of objects, +To save and load models, `tf.train.Checkpoint` stores the internal state of objects, without requiring hidden variables. To record the state of a `model`, -an `optimizer`, and a global step, pass them to a `tfe.Checkpoint`: +an `optimizer`, and a global step, pass them to a `tf.train.Checkpoint`: ```py model = MyModel() optimizer = tf.train.AdamOptimizer(learning_rate=0.001) checkpoint_dir = ‘/path/to/model_dir’ checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") -root = tfe.Checkpoint(optimizer=optimizer, - model=model, - optimizer_step=tf.train.get_or_create_global_step()) +root = tf.train.Checkpoint(optimizer=optimizer, + model=model, + optimizer_step=tf.train.get_or_create_global_step()) root.save(file_prefix=checkpoint_prefix) # or @@ -824,7 +824,7 @@ gives you eager's interactive experimentation and debuggability with the distributed performance benefits of graph execution. Write, debug, and iterate in eager execution, then import the model graph for -production deployment. Use `tfe.Checkpoint` to save and restore model +production deployment. Use `tf.train.Checkpoint` to save and restore model variables, this allows movement between eager and graph execution environments. See the examples in: [tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples). diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index c242ef3fdd..de63ebe9e6 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -626,7 +626,7 @@ def train_or_infer_spinn(embed, model = SNLIClassifier(config, embed) global_step = tf.train.get_or_create_global_step() trainer = SNLIClassifierTrainer(model, config.lr) - checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step) + checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step) checkpoint.restore(tf.train.latest_checkpoint(config.logdir)) if inference_sentence_pair: |