author     Allen Lavoie <allenl@google.com>  2018-04-10 12:54:03 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-10 12:56:36 -0700
commit  4bf8270ed534c4cd37160e757d7b8a3dc765d1f0 (patch)
tree    f02b383fe67c07b625437bd0932470031a8343ba /tensorflow/contrib/optimizer_v2
parent  22a5485a4f0db8d45efc30492499cba79cc1a47e (diff)
Checkpointable: wrap restore ops in init_scope
This should make restore() work with defun-wrapped code when variables are created inside the function: it just lifts the restore code into the outer context. Adds a test for it.

PiperOrigin-RevId: 192331065
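A minimal sketch of the mechanism, not the actual Checkpointable internals; the helper name _restore_into is made up for illustration, while ops.init_scope() and Variable.assign() are real TensorFlow APIs. Anything created under init_scope is lifted out of the function-building graph into the outer context, so an assignment issued while a defun is being traced still runs where restore() was originally called:

from tensorflow.python.framework import ops


def _restore_into(variable, restored_value):
  """Hypothetical helper: assign a checkpointed value during a defun trace."""
  with ops.init_scope():
    # The assign op is created in the outer graph (or eager context),
    # not in the function-building graph that is currently active.
    return variable.assign(restored_value)

The test added below covers the caller-facing side: restore() is issued before the model's variables exist, and the variables are only created once the defun-wrapped call builds them.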
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py  45
1 file changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 08f9699e85..abcffeb618 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adam
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -372,6 +373,50 @@ class CheckpointingTests(test.TestCase):
        self.assertEqual(training_continuation + 1,
                         self.evaluate(root.save_counter))
+  # pylint: disable=cell-var-from-loop
+  @test_util.run_in_graph_and_eager_modes()
+  def testWithDefun(self):
+    num_training_steps = 2
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    for training_continuation in range(3):
+      with ops.Graph().as_default(), self.test_session(
+          graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+        model = MyModel()
+        # Don't actually train so we can test variable values
+        optimizer = adam.AdamOptimizer(0.)
+        root = checkpointable_utils.Checkpoint(
+            optimizer=optimizer, model=model,
+            global_step=training_util.get_or_create_global_step())
+        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+        status = root.restore(save_path=checkpoint_path)
+        def train_fn():
+          @function.defun
+          def _call_model(x):
+            return model(x)
+          with backprop.GradientTape() as tape:
+            loss = _call_model(constant_op.constant([[3.]]))
+          gradients = tape.gradient(loss, model.variables)
+          return optimizer.apply_gradients(zip(gradients, model.variables),
+                                           global_step=root.global_step)
+        if not context.executing_eagerly():
+          train_fn = functools.partial(
+              self.evaluate, train_fn())
+        status.initialize_or_restore()
+        for _ in range(num_training_steps):
+          train_fn()
+        if training_continuation > 0:
+          status.assert_consumed()
+          self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
+        else:
+          self.evaluate(model.variables[0].assign([[42.]]))
+        root.save(file_prefix=checkpoint_prefix)
+        self.assertEqual((training_continuation + 1) * num_training_steps,
+                         self.evaluate(root.global_step))
+        self.assertEqual(training_continuation + 1,
+                         self.evaluate(root.save_counter))
+  # pylint: enable=cell-var-from-loop
+
  def _get_checkpoint_name(self, name):
    root = checkpointable.Checkpointable()
    checkpointable_utils.add_variable(