aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-10-03 16:41:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 16:53:02 -0700
commitaeb044c9784d30a25c0d15fa31f479001be55052 (patch)
treead370ec8d99e8277808f3ace522a76c5f8c0d188 /tensorflow/python/keras
parentd5b362a67a57f53f610536ed6068a5b67bc37b88 (diff)
assert_nontrivial_match in tf.keras.Model.load_weights (TF format)
Adds a bit of sanity checking by default to load_weights (e.g. for the case when absolutely nothing matches) while still supporting restore-on-create and the addition of new Layers to checkpointed models. PiperOrigin-RevId: 215652168
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/network.py1
-rw-r--r--tensorflow/python/keras/engine/saving_test.py13
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 5ef8d13487..8d34006967 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1526,6 +1526,7 @@ class Network(base_layer.Layer):
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
checkpointable_utils.streaming_restore(status=status, session=session)
+ status.assert_nontrivial_match()
return status
if h5py is None:
raise ImportError(
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 02d99d5d69..f5045be907 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
+from tensorflow.python.training.checkpointable import util as checkpointable
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -922,6 +923,18 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
+ @test_util.run_in_graph_and_eager_modes
+ def test_incompatible_checkpoint(self):
+ save_path = checkpointable.Checkpoint().save(
+ os.path.join(self.get_temp_dir(), 'ckpt'))
+ m = keras.Model()
+ with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
+ m.load_weights(save_path)
+ m.dense = keras.layers.Dense(2)
+ m.dense(constant_op.constant([[1.]]))
+ with self.assertRaisesRegexp(
+ AssertionError, 'Nothing except the root object matched'):
+ m.load_weights(save_path)
if __name__ == '__main__':
test.main()