diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-01 18:01:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 18:09:35 -0700 |
commit | 07ef0a28a9ddb5b661c38f383fecd2b3c239468d (patch) | |
tree | 77be7499dec067c8d5281ba3735950ef941e2282 /tensorflow/python | |
parent | 6e1714de56291165106fd9bee270ada964b34f88 (diff) |
Add a warning for Model.save_weights in TensorFlow format with a Keras optimizer
save_weights in HDF5 format does not save optimizer weights anyway, but since TensorFlow optimizers are saved in TensorFlow format it's a bit surprising when Keras optimizers aren't.
PiperOrigin-RevId: 207027546
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 10 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/saving_test.py | 17 |
2 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 9fa3969dac..e278d74c2f 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1457,6 +1457,16 @@ class Network(base_layer.Layer): session = None else: session = backend.get_session() + optimizer = getattr(self, 'optimizer', None) + if (optimizer + and not isinstance(optimizer, checkpointable.CheckpointableBase)): + logging.warning( + ('This model was compiled with a Keras optimizer (%s) but is being ' + 'saved in TensorFlow format with `save_weights`. The model\'s ' + 'weights will be saved, but unlike with TensorFlow optimizers in ' + 'the TensorFlow format the optimizer\'s state will not be ' + 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.') + % (optimizer,)) self._checkpointable_saver.save(filepath, session=session) def load_weights(self, filepath, by_name=False): diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index e029e614e0..f2f8a27b76 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training as training_module try: @@ -663,6 +664,22 @@ class SubclassedModel(training.Model): class TestWeightSavingAndLoadingTFFormat(test.TestCase): + def test_keras_optimizer_warning(self): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + model._make_train_function() + temp_dir = self.get_temp_dir() + prefix = os.path.join(temp_dir, 'ckpt') + with test.mock.patch.object(logging, 'warning') as mock_log: + model.save_weights(prefix) + self.assertRegexpMatches( + str(mock_log.call_args), + 'Keras optimizer') + @test_util.run_in_graph_and_eager_modes def test_tensorflow_format_overwrite(self): with self.test_session() as session: |