aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-15 14:26:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 14:30:04 -0700
commitaf86ca4983fe14214e032e0d76cb3a08dc8e1e9e (patch)
tree84a03a9cf8cacb6f8378c9e6fadf95eecf5f1a0b /tensorflow/contrib/optimizer_v2
parentc430edbb088a96db529a0a13438d1f629e48b6f0 (diff)
Checkpointable: Restore-on-create for name-based checkpoints when executing eagerly
Should make loading name-based checkpoints more natural with object-based APIs when executing eagerly. Before this CL they could be loaded, but users needed to use "run_restore_ops" after all variables were created (which is less useful and confusing). PiperOrigin-RevId: 196729311
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index b1f2e9d860..20316ec0e3 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -722,12 +722,22 @@ class CheckpointCompatibilityTests(test.TestCase):
with self.assertRaises(AssertionError):
self._check_sentinels(root)
object_saver = checkpointable_utils.CheckpointableSaver(root)
+ self._set_sentinels(root)
status = object_saver.restore(save_path)
- with self.assertRaises(AssertionError):
- status.assert_consumed()
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
+ status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)