aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-27 16:47:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 16:51:39 -0700
commitf577ae972f457cd7ba8dc8be14a80d8d6e27b8cb (patch)
tree788af09acc7cce36e479dd8a0a2af0ff6dc856df /tensorflow/python/training
parent9500c1d80de70dabd1b538287a667c6fda0c394d (diff)
Checkpointable: Fix the ignore-missing logic for name-based checkpoint restores
Restore previously checked if a key existed, but didn't quite ignore that value properly if it was missing. PiperOrigin-RevId: 210455409
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/checkpointable/util.py13
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py6
2 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index d1b50d1362..45d217e8b1 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -199,6 +199,7 @@ class _NameBasedRestoreCoordinator(object):
for saveable in self.globally_named_object_attributes(
checkpointable):
restored_tensors = []
+ tensor_missing = False
for spec in saveable.specs:
if spec.name in self.dtype_map:
with ops.device("cpu:0"):
@@ -209,9 +210,15 @@ class _NameBasedRestoreCoordinator(object):
dtypes=[self.dtype_map[spec.name]],
name="%s_checkpoint_read" % (spec.name,))
restored_tensors.append(array_ops.identity(restored))
-
- saveable.restore(restored_tensors=restored_tensors,
- restored_shapes=None)
+ else:
+ tensor_missing = True
+
+ if not tensor_missing:
+ # Ignores values missing from the checkpoint, as with object-based
+ # restore. Status assertions can be used to check exact matches,
+ # although it's unlikely to ever happen for name-based checkpoints.
+ saveable.restore(restored_tensors=restored_tensors,
+ restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 697b44c3ff..bef4bf2a16 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -1482,6 +1482,12 @@ class CheckpointCompatibilityTests(test.TestCase):
status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)
+ # Check that there is no error when keys are missing from the name-based
+ # checkpoint.
+ root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.])
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
def testSaveGraphLoadEager(self):
checkpoint_directory = self.get_temp_dir()