author     Asim Shankar <ashankar@google.com>  2018-07-19 08:56:07 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-19 09:02:54 -0700
commit     2509b3a2152c8dda9fff8ed58f414c1316fa5379 (patch)
tree       4660b2dc70d623d8270cff9545edacdcd534a8fd
parent     e9e48b963b1ad1274ad8a0ad7d07d7fa990fe6b9 (diff)
eager guide: s/tfe.Checkpoint/tf.train.Checkpoint/
PiperOrigin-RevId: 205248470
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist.py        5
-rw-r--r--  tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py  2
-rw-r--r--  tensorflow/docs_src/guide/eager.md                           16
-rw-r--r--  third_party/examples/eager/spinn/spinn.py                    2
4 files changed, 12 insertions(+), 13 deletions(-)
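
The change is mechanical: every `tfe.Checkpoint` call site now spells the same class as `tf.train.Checkpoint`. A minimal sketch of the renamed API, assuming TensorFlow 1.9-era eager execution; the path is illustrative:

```py
import tensorflow as tf

tf.enable_eager_execution()

x = tf.Variable(10.)
checkpoint = tf.train.Checkpoint(x=x)      # previously: tfe.Checkpoint(x=x)
save_path = checkpoint.save('/tmp/ckpt/')  # serializes x under the name "x"

x.assign(2.)                   # mutate the variable...
checkpoint.restore(save_path)  # ...then recover the saved value
print(x.numpy())               # => 10.0
```

Because only the symbol's import path changes, checkpoints written before the rename should remain readable after it.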
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index b33243021b..9a42179299 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -29,7 +29,6 @@ import time
import tensorflow as tf
-import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data
layers = tf.keras.layers
@@ -265,7 +264,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
def main(_):
(device, data_format) = ('/gpu:0', 'channels_first')
- if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+ if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
(device, data_format) = ('/cpu:0', 'channels_last')
print('Using device %s, and data format %s.' % (device, data_format))
@@ -291,7 +290,7 @@ def main(_):
latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_cpkt:
print('Using latest checkpoint at ' + latest_cpkt)
- checkpoint = tfe.Checkpoint(**model_objects)
+ checkpoint = tf.train.Checkpoint(**model_objects)
# Restore variables on creation if a checkpoint exists.
checkpoint.restore(latest_cpkt)
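
The GAN hunk above uses the restore-on-creation idiom: build the `Checkpoint` from a dict of trackable objects, then call `restore` with whatever `tf.train.latest_checkpoint` returns. A hedged sketch of that idiom; `model_objects` and the directory are hypothetical stand-ins for the example's real generator, discriminator, and optimizers:

```py
import tensorflow as tf

tf.enable_eager_execution()

# Stand-ins for the model objects assembled in mnist.py's main().
model_objects = {
    'generator': tf.keras.Sequential([tf.keras.layers.Dense(784)]),
    'step_counter': tf.train.get_or_create_global_step(),
}
checkpoint = tf.train.Checkpoint(**model_objects)

# latest_checkpoint returns None when the directory holds no checkpoint;
# restore(None) is then an initialization-only no-op, so first runs work too.
latest_cpkt = tf.train.latest_checkpoint('/tmp/gan_checkpoints')
checkpoint.restore(latest_cpkt)
```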
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index d64bf5354e..15776c694e 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -315,7 +315,7 @@ def main(_):
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
use_cudnn_rnn)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
- checkpoint = tfe.Checkpoint(
+ checkpoint = tf.train.Checkpoint(
learning_rate=learning_rate, model=model,
# GradientDescentOptimizer has no state to checkpoint, but noting it
# here lets us swap in an optimizer that does.
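
The comment this hunk preserves is worth unpacking: `GradientDescentOptimizer` creates no slot variables, so checkpointing it writes nothing today, but naming it in the `Checkpoint` reserves its place in the object graph for a stateful replacement (e.g. `tf.train.AdamOptimizer` and its slots) in a later run. A sketch under those assumptions, with a placeholder model and path:

```py
import tensorflow as tf

tf.enable_eager_execution()

learning_rate = tf.Variable(0.1, name='learning_rate')
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model(tf.zeros([1, 2]))  # create the model's variables
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# 'optimizer' is currently stateless; naming it anyway keeps the checkpoint
# layout compatible if an optimizer with slot variables is swapped in later.
checkpoint = tf.train.Checkpoint(
    learning_rate=learning_rate, model=model, optimizer=optimizer)
checkpoint.save('/tmp/rnn_ptb_ckpt/ckpt')
```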
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index 42ad9652f8..3b54d6d2bb 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -504,13 +504,13 @@ with tf.device("gpu:0"):
### Object-based saving
-`tfe.Checkpoint` can save and restore `tf.Variable`s to and from
+`tf.train.Checkpoint` can save and restore `tf.Variable`s to and from
checkpoints:
```py
x = tf.Variable(10.)
-checkpoint = tfe.Checkpoint(x=x) # save as "x"
+checkpoint = tf.train.Checkpoint(x=x) # save as "x"
x.assign(2.) # Assign a new value to the variables and save.
save_path = checkpoint.save('./ckpt/')
@@ -523,18 +523,18 @@ checkpoint.restore(save_path)
print(x) # => 2.0
```
-To save and load models, `tfe.Checkpoint` stores the internal state of objects,
+To save and load models, `tf.train.Checkpoint` stores the internal state of objects,
without requiring hidden variables. To record the state of a `model`,
-an `optimizer`, and a global step, pass them to a `tfe.Checkpoint`:
+an `optimizer`, and a global step, pass them to a `tf.train.Checkpoint`:
```py
model = MyModel()
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
checkpoint_dir = '/path/to/model_dir'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
-root = tfe.Checkpoint(optimizer=optimizer,
- model=model,
- optimizer_step=tf.train.get_or_create_global_step())
+root = tf.train.Checkpoint(optimizer=optimizer,
+ model=model,
+ optimizer_step=tf.train.get_or_create_global_step())
root.save(file_prefix=checkpoint_prefix)
# or
@@ -824,7 +824,7 @@ gives you eager's interactive experimentation and debuggability with the
distributed performance benefits of graph execution.
Write, debug, and iterate in eager execution, then import the model graph for
-production deployment. Use `tfe.Checkpoint` to save and restore model
+production deployment. Use `tf.train.Checkpoint` to save and restore model
variables; this allows movement between eager and graph execution environments.
See the examples in:
[tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples).
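
The guide's claim that `tf.train.Checkpoint` allows movement between eager and graph execution can be made concrete. A hedged sketch assuming TF 1.9-era APIs; the two phases would normally live in separate programs, so the eager half is shown as comments:

```py
import tensorflow as tf

def make_model():
  return tf.keras.Sequential([tf.keras.layers.Dense(1)])

# Phase 1 (eager process): train and save.
#   tf.enable_eager_execution()
#   model = make_model()
#   model(tf.zeros([1, 3]))  # create variables
#   tf.train.Checkpoint(model=model).save('/tmp/shared_ckpt/ckpt')

# Phase 2 (graph process): rebuild the same object graph and restore.
with tf.Graph().as_default():
  model = make_model()
  outputs = model(tf.zeros([1, 3]))
  status = tf.train.Checkpoint(model=model).restore(
      tf.train.latest_checkpoint('/tmp/shared_ckpt'))
  with tf.Session() as sess:
    status.initialize_or_restore(sess)  # runs restore (or init) ops
    print(sess.run(outputs))
```

Matching happens through the object graph (the `model=` name here), not through graph collections, which is what makes the same checkpoint usable in both modes.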
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index c242ef3fdd..de63ebe9e6 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -626,7 +626,7 @@ def train_or_infer_spinn(embed,
model = SNLIClassifier(config, embed)
global_step = tf.train.get_or_create_global_step()
trainer = SNLIClassifierTrainer(model, config.lr)
- checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair:
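
The SPINN hunk follows the same pattern as the guide: the global step is checkpointed next to the trainer, so a restarted run resumes its step count. A minimal sketch with hypothetical stand-ins for `SNLIClassifierTrainer` and `config.logdir`:

```py
import tensorflow as tf

tf.enable_eager_execution()

global_step = tf.train.get_or_create_global_step()
trainer = tf.keras.Sequential([tf.keras.layers.Dense(3)])  # stand-in object
checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)

# A harmless no-op on a fresh logdir; restores trainer and step otherwise.
checkpoint.restore(tf.train.latest_checkpoint('/tmp/spinn_logdir'))
print('resuming at step', int(global_step.numpy()))
```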