diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-27 16:47:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 16:51:39 -0700 |
commit | f577ae972f457cd7ba8dc8be14a80d8d6e27b8cb (patch) | |
tree | 788af09acc7cce36e479dd8a0a2af0ff6dc856df /tensorflow/python/training | |
parent | 9500c1d80de70dabd1b538287a667c6fda0c394d (diff) |
Checkpointable: Fix the ignore-missing logic for name-based checkpoint restores
Restore previously checked if a key existed, but didn't quite ignore that value properly if it was missing.
PiperOrigin-RevId: 210455409
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/checkpointable/util.py | 13 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/util_test.py | 6 |
2 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index d1b50d1362..45d217e8b1 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -199,6 +199,7 @@ class _NameBasedRestoreCoordinator(object): for saveable in self.globally_named_object_attributes( checkpointable): restored_tensors = [] + tensor_missing = False for spec in saveable.specs: if spec.name in self.dtype_map: with ops.device("cpu:0"): @@ -209,9 +210,15 @@ class _NameBasedRestoreCoordinator(object): dtypes=[self.dtype_map[spec.name]], name="%s_checkpoint_read" % (spec.name,)) restored_tensors.append(array_ops.identity(restored)) - - saveable.restore(restored_tensors=restored_tensors, - restored_shapes=None) + else: + tensor_missing = True + + if not tensor_missing: + # Ignores values missing from the checkpoint, as with object-based + # restore. Status assertions can be used to check exact matches, + # although it's unlikely to ever happen for name-based checkpoints. + saveable.restore(restored_tensors=restored_tensors, + restored_shapes=None) # TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 697b44c3ff..bef4bf2a16 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -1482,6 +1482,12 @@ class CheckpointCompatibilityTests(test.TestCase): status = object_saver.restore(save_path) status.initialize_or_restore() self._check_sentinels(root) + # Check that there is no error when keys are missing from the name-based + # checkpoint. + root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.]) + status = object_saver.restore(save_path) + with self.assertRaises(AssertionError): + status.assert_existing_objects_matched() def testSaveGraphLoadEager(self): checkpoint_directory = self.get_temp_dir() |