diff options
author | Allen Lavoie <allenl@google.com> | 2018-10-03 16:41:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 16:53:02 -0700 |
commit | aeb044c9784d30a25c0d15fa31f479001be55052 (patch) | |
tree | ad370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python | |
parent | d5b362a67a57f53f610536ed6068a5b67bc37b88 (diff) |
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models.
PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 1 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/saving_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/util.py | 56 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/util_test.py | 5 |
4 files changed, 71 insertions, 4 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 5ef8d13487..8d34006967 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1526,6 +1526,7 @@ class Network(base_layer.Layer): # Restore existing variables (if any) immediately, and set up a # streaming restore for any variables created in the future. checkpointable_utils.streaming_restore(status=status, session=session) + status.assert_nontrivial_match() return status if h5py is None: raise ImportError( diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 02d99d5d69..f5045be907 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -38,6 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as training_module +from tensorflow.python.training.checkpointable import util as checkpointable try: import h5py # pylint:disable=g-import-not-at-top @@ -922,6 +923,18 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): SubclassedModel, SubclassedModelRestore, _restore_init_fn) + @test_util.run_in_graph_and_eager_modes + def test_incompatible_checkpoint(self): + save_path = checkpointable.Checkpoint().save( + os.path.join(self.get_temp_dir(), 'ckpt')) + m = keras.Model() + with self.assertRaisesRegexp(AssertionError, 'Nothing to load'): + m.load_weights(save_path) + m.dense = keras.layers.Dense(2) + m.dense(constant_op.constant([[1.]])) + with self.assertRaisesRegexp( + AssertionError, 'Nothing except the root object matched'): + m.load_weights(save_path) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index eff15b24ce..edab6cc6eb 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -854,6 +854,11 @@ class _LoadStatus(object): pass @abc.abstractmethod + def assert_nontrivial_match(self): + """Raises an exception if only the root object matched.""" + pass + + @abc.abstractmethod def run_restore_ops(self, session=None): """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" pass @@ -975,6 +980,26 @@ class CheckpointLoadStatus(_LoadStatus): % (list(unused_python_objects),)) return self + def assert_nontrivial_match(self): + """Raises an exception if only the root object matched.""" + for checkpointable_object in list_objects(self._root_checkpointable): + self._checkpoint.all_python_objects.add(checkpointable_object) + if len(self._checkpoint.object_by_proto_id) <= 1: + unused_python_objects = ( + _ObjectIdentitySet(self._checkpoint.all_python_objects) + - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values())) + if unused_python_objects: + raise AssertionError( + ("Nothing except the root object matched a checkpointed value. " + "Typically this means that the checkpoint does not match the " + "Python program. The following objects have no matching " + "checkpointed value: %s") % (list(unused_python_objects),)) + else: + raise AssertionError( + "Nothing to load. No dependencies have been added to %s yet." % ( + self._root_checkpointable,)) + return self + def run_restore_ops(self, session=None): """Run operations to restore objects in the dependency graph.""" if context.executing_eagerly(): @@ -1039,6 +1064,11 @@ class InitializationOnlyStatus(_LoadStatus): raise AssertionError( "No checkpoint specified (save_path=None); nothing is being restored.") + def assert_nontrivial_match(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "No checkpoint specified (save_path=None); nothing is being restored.") + def run_restore_ops(self, session=None): """For consistency with `CheckpointLoadStatus`. @@ -1122,6 +1152,14 @@ class NameBasedSaverStatus(_LoadStatus): # useful since we don't touch Python objects or Python state). return self.assert_consumed() + def assert_nontrivial_match(self): + """Raises an exception if currently created objects are unmatched.""" + # For name-based checkpoints there's no object information in the + # checkpoint, so there's no distinction between + # assert_nontrivial_match and assert_consumed (and both are less + # useful since we don't touch Python objects or Python state). + return self.assert_consumed() + def _gather_saveable_objects(self): """Walk the object graph, using global names for SaveableObjects.""" objects = list_objects(self._root_checkpointable) @@ -1779,13 +1817,15 @@ class Checkpoint(tracking.Checkpointable): status of a checkpoint restoration and run initialization/restore ops. The returned status object has the following methods: - - `assert_consumed()`: + + * `assert_consumed()`: Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`. - - `assert_existing_objects_matched()`: + + * `assert_existing_objects_matched()`: Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike `assert_consumed`, this assertion will pass if values in the checkpoint have no corresponding Python @@ -1796,12 +1836,20 @@ class Checkpoint(tracking.Checkpointable): a `tf.train.Optimizer` was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`. - - `initialize_or_restore(session=None)`: + + * `assert_nontrivial_match()`: Asserts that something aside from the root + object was matched. This is a very weak assertion, but is useful for + sanity checking in library code where objects may exist in the + checkpoint which haven't been created in Python and some Python + objects may not have a checkpointed value. + + * `initialize_or_restore(session=None)`: When graph building, runs variable initializers if `save_path` is `None`, but otherwise runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly). - - `run_restore_ops(session=None)`: + + * `run_restore_ops(session=None)`: When graph building, runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index f8b5bd8501..14b47a1940 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -437,6 +437,7 @@ class CheckpointingTests(test.TestCase): optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) + status.assert_nontrivial_match() status.assert_existing_objects_matched() with self.assertRaises(AssertionError): status.assert_consumed() @@ -1509,6 +1510,8 @@ class CheckpointCompatibilityTests(test.TestCase): status.assert_consumed() with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): status.assert_existing_objects_matched() + with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): + status.assert_nontrivial_match() else: # When graph building, we haven't read any keys, so we don't know # whether the restore will be complete. @@ -1516,6 +1519,8 @@ class CheckpointCompatibilityTests(test.TestCase): status.assert_consumed() with self.assertRaisesRegexp(AssertionError, "not restored"): status.assert_existing_objects_matched() + with self.assertRaisesRegexp(AssertionError, "not restored"): + status.assert_nontrivial_match() status.run_restore_ops() self._check_sentinels(root) self._set_sentinels(root) |