diff options
Diffstat (limited to 'third_party/examples/eager/spinn/spinn.py')
-rw-r--r-- | third_party/examples/eager/spinn/spinn.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 67456a5bdf..de63ebe9e6 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable): # Create a custom learning rate Variable for the RMSProp optimizer, because # the learning rate needs to be manually decayed later (see # decay_learning_rate()). - self._learning_rate = tfe.Variable(lr, name="learning_rate") + self._learning_rate = tf.Variable(lr, name="learning_rate") self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate, epsilon=1e-6) @@ -626,7 +626,7 @@ def train_or_infer_spinn(embed, model = SNLIClassifier(config, embed) global_step = tf.train.get_or_create_global_step() trainer = SNLIClassifierTrainer(model, config.lr) - checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step) + checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step) checkpoint.restore(tf.train.latest_checkpoint(config.logdir)) if inference_sentence_pair: |