diff options
author | 2018-10-03 16:41:21 -0700 | |
---|---|---|
committer | 2018-10-03 16:53:02 -0700 | |
commit | aeb044c9784d30a25c0d15fa31f479001be55052 (patch) | |
tree | ad370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python/keras | |
parent | d5b362a67a57f53f610536ed6068a5b67bc37b88 (diff) |
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models.
PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 1 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/saving_test.py | 13 |
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 5ef8d13487..8d34006967 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1526,6 +1526,7 @@ class Network(base_layer.Layer): # Restore existing variables (if any) immediately, and set up a # streaming restore for any variables created in the future. checkpointable_utils.streaming_restore(status=status, session=session) + status.assert_nontrivial_match() return status if h5py is None: raise ImportError( diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 02d99d5d69..f5045be907 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -38,6 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as training_module +from tensorflow.python.training.checkpointable import util as checkpointable try: import h5py # pylint:disable=g-import-not-at-top @@ -922,6 +923,18 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): SubclassedModel, SubclassedModelRestore, _restore_init_fn) + @test_util.run_in_graph_and_eager_modes + def test_incompatible_checkpoint(self): + save_path = checkpointable.Checkpoint().save( + os.path.join(self.get_temp_dir(), 'ckpt')) + m = keras.Model() + with self.assertRaisesRegexp(AssertionError, 'Nothing to load'): + m.load_weights(save_path) + m.dense = keras.layers.Dense(2) + m.dense(constant_op.constant([[1.]])) + with self.assertRaisesRegexp( + AssertionError, 'Nothing except the root object matched'): + m.load_weights(save_path) if __name__ == '__main__': test.main() |