aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/util_test.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-10-03 16:41:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 16:53:02 -0700
commitaeb044c9784d30a25c0d15fa31f479001be55052 (patch)
treead370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python/training/checkpointable/util_test.py
parentd5b362a67a57f53f610536ed6068a5b67bc37b88 (diff)
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models. PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index f8b5bd8501..14b47a1940 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -437,6 +437,7 @@ class CheckpointingTests(test.TestCase):
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
+ status.assert_nontrivial_match()
status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
@@ -1509,6 +1510,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_nontrivial_match()
else:
# When graph building, we haven't read any keys, so we don't know
# whether the restore will be complete.
@@ -1516,6 +1519,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "not restored"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_nontrivial_match()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)