aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/examples/eager/spinn/spinn.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/examples/eager/spinn/spinn.py')
-rw-r--r--third_party/examples/eager/spinn/spinn.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 67456a5bdf..de63ebe9e6 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
# Create a custom learning rate Variable for the RMSProp optimizer, because
# the learning rate needs to be manually decayed later (see
# decay_learning_rate()).
- self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._learning_rate = tf.Variable(lr, name="learning_rate")
self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
epsilon=1e-6)
@@ -626,7 +626,7 @@ def train_or_infer_spinn(embed,
model = SNLIClassifier(config, embed)
global_step = tf.train.get_or_create_global_step()
trainer = SNLIClassifierTrainer(model, config.lr)
- checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair: