diff options
author | 2018-10-03 16:41:21 -0700 | |
---|---|---|
committer | 2018-10-03 16:53:02 -0700 | |
commit | aeb044c9784d30a25c0d15fa31f479001be55052 (patch) | |
tree | ad370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python/training/checkpointable/util_test.py | |
parent | d5b362a67a57f53f610536ed6068a5b67bc37b88 (diff) |
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models.
PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/util_test.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index f8b5bd8501..14b47a1940 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -437,6 +437,7 @@ class CheckpointingTests(test.TestCase): optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) + status.assert_nontrivial_match() status.assert_existing_objects_matched() with self.assertRaises(AssertionError): status.assert_consumed() @@ -1509,6 +1510,8 @@ class CheckpointCompatibilityTests(test.TestCase): status.assert_consumed() with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): status.assert_existing_objects_matched() + with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"): + status.assert_nontrivial_match() else: # When graph building, we haven't read any keys, so we don't know # whether the restore will be complete. @@ -1516,6 +1519,8 @@ class CheckpointCompatibilityTests(test.TestCase): status.assert_consumed() with self.assertRaisesRegexp(AssertionError, "not restored"): status.assert_existing_objects_matched() + with self.assertRaisesRegexp(AssertionError, "not restored"): + status.assert_nontrivial_match() status.run_restore_ops() self._check_sentinels(root) self._set_sentinels(root) |