aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-10-03 16:41:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 16:53:02 -0700
commitaeb044c9784d30a25c0d15fa31f479001be55052 (patch)
treead370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python
parentd5b362a67a57f53f610536ed6068a5b67bc37b88 (diff)
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models. PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/keras/engine/network.py1
-rw-r--r--tensorflow/python/keras/engine/saving_test.py13
-rw-r--r--tensorflow/python/training/checkpointable/util.py56
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py5
4 files changed, 71 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 5ef8d13487..8d34006967 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1526,6 +1526,7 @@ class Network(base_layer.Layer):
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
checkpointable_utils.streaming_restore(status=status, session=session)
+ status.assert_nontrivial_match()
return status
if h5py is None:
raise ImportError(
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 02d99d5d69..f5045be907 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
+from tensorflow.python.training.checkpointable import util as checkpointable
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -922,6 +923,18 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
+ @test_util.run_in_graph_and_eager_modes
+ def test_incompatible_checkpoint(self):
+ save_path = checkpointable.Checkpoint().save(
+ os.path.join(self.get_temp_dir(), 'ckpt'))
+ m = keras.Model()
+ with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
+ m.load_weights(save_path)
+ m.dense = keras.layers.Dense(2)
+ m.dense(constant_op.constant([[1.]]))
+ with self.assertRaisesRegexp(
+ AssertionError, 'Nothing except the root object matched'):
+ m.load_weights(save_path)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index eff15b24ce..edab6cc6eb 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -854,6 +854,11 @@ class _LoadStatus(object):
pass
@abc.abstractmethod
+ def assert_nontrivial_match(self):
+ """Raises an exception if only the root object matched."""
+ pass
+
+ @abc.abstractmethod
def run_restore_ops(self, session=None):
"""Runs restore ops from the checkpoint. Requires a valid checkpoint."""
pass
@@ -975,6 +980,26 @@ class CheckpointLoadStatus(_LoadStatus):
% (list(unused_python_objects),))
return self
+ def assert_nontrivial_match(self):
+ """Raises an exception if only the root object matched."""
+ for checkpointable_object in list_objects(self._root_checkpointable):
+ self._checkpoint.all_python_objects.add(checkpointable_object)
+ if len(self._checkpoint.object_by_proto_id) <= 1:
+ unused_python_objects = (
+ _ObjectIdentitySet(self._checkpoint.all_python_objects)
+ - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values()))
+ if unused_python_objects:
+ raise AssertionError(
+ ("Nothing except the root object matched a checkpointed value. "
+ "Typically this means that the checkpoint does not match the "
+ "Python program. The following objects have no matching "
+ "checkpointed value: %s") % (list(unused_python_objects),))
+ else:
+ raise AssertionError(
+ "Nothing to load. No dependencies have been added to %s yet." % (
+ self._root_checkpointable,))
+ return self
+
def run_restore_ops(self, session=None):
"""Run operations to restore objects in the dependency graph."""
if context.executing_eagerly():
@@ -1039,6 +1064,11 @@ class InitializationOnlyStatus(_LoadStatus):
raise AssertionError(
"No checkpoint specified (save_path=None); nothing is being restored.")
+ def assert_nontrivial_match(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "No checkpoint specified (save_path=None); nothing is being restored.")
+
def run_restore_ops(self, session=None):
"""For consistency with `CheckpointLoadStatus`.
@@ -1122,6 +1152,14 @@ class NameBasedSaverStatus(_LoadStatus):
# useful since we don't touch Python objects or Python state).
return self.assert_consumed()
+ def assert_nontrivial_match(self):
+ """Raises an exception if currently created objects are unmatched."""
+ # For name-based checkpoints there's no object information in the
+ # checkpoint, so there's no distinction between
+ # assert_nontrivial_match and assert_consumed (and both are less
+ # useful since we don't touch Python objects or Python state).
+ return self.assert_consumed()
+
def _gather_saveable_objects(self):
"""Walk the object graph, using global names for SaveableObjects."""
objects = list_objects(self._root_checkpointable)
@@ -1779,13 +1817,15 @@ class Checkpoint(tracking.Checkpointable):
status of a checkpoint restoration and run initialization/restore ops.
The returned status object has the following methods:
- - `assert_consumed()`:
+
+ * `assert_consumed()`:
Raises an exception if any variables/objects are unmatched: either
checkpointed values which don't have a matching Python object or
Python objects in the dependency graph with no values in the
checkpoint. This method returns the status object, and so may be
chained with `initialize_or_restore` or `run_restore_ops`.
- - `assert_existing_objects_matched()`:
+
+ * `assert_existing_objects_matched()`:
Raises an exception if any existing Python objects in the dependency
graph are unmatched. Unlike `assert_consumed`, this assertion will
pass if values in the checkpoint have no corresponding Python
@@ -1796,12 +1836,20 @@ class Checkpoint(tracking.Checkpointable):
a `tf.train.Optimizer` was saved but only the state required for
inference is being loaded. This method returns the status object, and
so may be chained with `initialize_or_restore` or `run_restore_ops`.
- - `initialize_or_restore(session=None)`:
+
+ * `assert_nontrivial_match()`: Asserts that something aside from the root
+ object was matched. This is a very weak assertion, but is useful for
+ sanity checking in library code where objects may exist in the
+ checkpoint which haven't been created in Python and some Python
+ objects may not have a checkpointed value.
+
+ * `initialize_or_restore(session=None)`:
When graph building, runs variable initializers if `save_path` is
`None`, but otherwise runs restore operations. If no `session` is
explicitly specified, the default session is used. No effect when
executing eagerly (variables are initialized or restored eagerly).
- - `run_restore_ops(session=None)`:
+
+ * `run_restore_ops(session=None)`:
When graph building, runs restore operations. If no `session` is
explicitly specified, the default session is used. No effect when
executing eagerly (restore operations are run eagerly). May only be
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index f8b5bd8501..14b47a1940 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -437,6 +437,7 @@ class CheckpointingTests(test.TestCase):
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
+ status.assert_nontrivial_match()
status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
@@ -1509,6 +1510,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_nontrivial_match()
else:
# When graph building, we haven't read any keys, so we don't know
# whether the restore will be complete.
@@ -1516,6 +1519,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "not restored"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_nontrivial_match()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)