aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-01 18:01:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 18:09:35 -0700
commit07ef0a28a9ddb5b661c38f383fecd2b3c239468d (patch)
tree77be7499dec067c8d5281ba3735950ef941e2282 /tensorflow/python
parent6e1714de56291165106fd9bee270ada964b34f88 (diff)
Add a warning for Model.save_weights in TensorFlow format with a Keras optimizer
save_weights in HDF5 format does not save optimizer weights anyway, but since TensorFlow optimizers are saved in TensorFlow format it's a bit surprising when Keras optimizers aren't. PiperOrigin-RevId: 207027546
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/keras/engine/network.py10
-rw-r--r--tensorflow/python/keras/engine/saving_test.py17
2 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 9fa3969dac..e278d74c2f 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1457,6 +1457,16 @@ class Network(base_layer.Layer):
session = None
else:
session = backend.get_session()
+ optimizer = getattr(self, 'optimizer', None)
+ if (optimizer
+ and not isinstance(optimizer, checkpointable.CheckpointableBase)):
+ logging.warning(
+ ('This model was compiled with a Keras optimizer (%s) but is being '
+ 'saved in TensorFlow format with `save_weights`. The model\'s '
+ 'weights will be saved, but unlike with TensorFlow optimizers in '
+ 'the TensorFlow format the optimizer\'s state will not be '
+ 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
+ % (optimizer,))
self._checkpointable_saver.save(filepath, session=session)
def load_weights(self, filepath, by_name=False):
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index e029e614e0..f2f8a27b76 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training as training_module
try:
@@ -663,6 +664,22 @@ class SubclassedModel(training.Model):
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
+ def test_keras_optimizer_warning(self):
+ graph = ops.Graph()
+ with graph.as_default(), self.test_session(graph):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+ model._make_train_function()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.save_weights(prefix)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ 'Keras optimizer')
+
@test_util.run_in_graph_and_eager_modes
def test_tensorflow_format_overwrite(self):
with self.test_session() as session: