diff options
author | 2018-05-15 14:26:21 -0700 | |
---|---|---|
committer | 2018-05-15 14:30:04 -0700 | |
commit | af86ca4983fe14214e032e0d76cb3a08dc8e1e9e (patch) | |
tree | 84a03a9cf8cacb6f8378c9e6fadf95eecf5f1a0b /tensorflow/contrib/optimizer_v2 | |
parent | c430edbb088a96db529a0a13438d1f629e48b6f0 (diff) |
Checkpointable: Restore-on-create for name-based checkpoints when executing eagerly
Should make loading name-based checkpoints more natural with object-based APIs when executing eagerly. Before this CL they could be loaded, but users needed to use "run_restore_ops" after all variables were created (which is less useful and confusing).
PiperOrigin-RevId: 196729311
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index b1f2e9d860..20316ec0e3 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -722,12 +722,22 @@ class CheckpointCompatibilityTests(test.TestCase): with self.assertRaises(AssertionError): self._check_sentinels(root) object_saver = checkpointable_utils.CheckpointableSaver(root) + self._set_sentinels(root) status = object_saver.restore(save_path) - with self.assertRaises(AssertionError): - status.assert_consumed() + if context.executing_eagerly(): + self._check_sentinels(root) + if context.executing_eagerly(): + with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): + status.assert_consumed() + else: + # When graph building, we haven't read any keys, so we don't know + # whether the restore will be complete. + with self.assertRaisesRegexp(AssertionError, "not restored"): + status.assert_consumed() status.run_restore_ops() self._check_sentinels(root) self._set_sentinels(root) + status = object_saver.restore(save_path) status.initialize_or_restore() self._check_sentinels(root) |