aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-27 21:17:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-27 21:21:46 -0700
commite3095dc262bfdda08d01fce105680515a3d1a7f4 (patch)
tree8891280202c8e4f952418ead287eb44b0b2bd87a
parent919e59cc49ada3d529e080ee8eebaaec7f621844 (diff)
Create a save_model and load_model util to support saving keras.Model to/from checkpoint. Currently, the model topology is still loaded from json (placed under saved_model/assets). Later, we will load from saved_model.pb.
PiperOrigin-RevId: 206412614
-rw-r--r--tensorflow/contrib/saved_model/BUILD29
-rw-r--r--tensorflow/contrib/saved_model/__init__.py3
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/__init__.py1
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py108
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py201
-rw-r--r--tensorflow/python/saved_model/constants.py6
6 files changed, 345 insertions, 3 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 26fd4e2023..fbb50befdf 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -93,3 +93,32 @@ py_test(
"//tensorflow/python/saved_model:utils",
],
)
+
+py_library(
+ name = "keras_saved_model",
+ srcs = ["python/saved_model/keras_saved_model.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/saved_model:constants",
+ ],
+)
+
+py_test(
+ name = "keras_saved_model_test",
+ size = "small",
+ srcs = ["python/saved_model/keras_saved_model_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":saved_model_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:training",
+ "//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index b4f27a055d..95e1a8967b 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -24,11 +24,12 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
# pylint: enable=unused-import,widcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ["get_signature_def_by_key"]
+_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
index 7b91622b61..e3b76bb6f3 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -24,5 +24,6 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
+from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
new file mode 100644
index 0000000000..e2a969f053
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Utility functions to save/load keras Model to/from SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras.models import model_from_json
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.saved_model import constants
+from tensorflow.python.util import compat
+
+
+def save_model(model, saved_model_path):
+ """Save a `tf.keras.Model` into Tensorflow SavedModel format.
+
+ `save_model` generates such files/folders under the `saved_model_path` folder:
+ 1) an asset folder containing the json string of the model's
+ configuration(topology).
+ 2) a checkpoint containing the model weights.
+
+ Note that subclassed models can not be saved via this function, unless you
+ provide an implementation for get_config() and from_config().
+ Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
+ saved to checkpoints. Use optimizers from `tf.train`.
+
+ Args:
+ model: A `tf.keras.Model` to be saved.
+ saved_model_path: a string specifying the path to the SavedModel directory.
+
+ Raises:
+ NotImplementedError: If the passed in model is a subclassed model.
+ """
+ if not model._is_graph_network:
+ raise NotImplementedError
+
+ # save model configuration as a json string under assets folder.
+ model_json = model.to_json()
+ assets_destination_dir = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.ASSETS_DIRECTORY))
+
+ if not file_io.file_exists(assets_destination_dir):
+ file_io.recursive_create_dir(assets_destination_dir)
+
+ model_json_filepath = os.path.join(
+ compat.as_bytes(assets_destination_dir),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ file_io.write_string_to_file(model_json_filepath, model_json)
+
+ # save model weights in checkpoint format.
+ checkpoint_destination_dir = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY))
+
+ if not file_io.file_exists(checkpoint_destination_dir):
+ file_io.recursive_create_dir(checkpoint_destination_dir)
+
+ checkpoint_prefix = os.path.join(
+ compat.as_text(checkpoint_destination_dir),
+ compat.as_text(constants.VARIABLES_FILENAME))
+ model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+
+
+def load_model(saved_model_path):
+ """Load a keras.Model from SavedModel.
+
+ load_model reinstantiates model state by:
+ 1) loading model topology from json (this will eventually come
+ from metagraph).
+ 2) loading model weights from checkpoint.
+
+ Args:
+ saved_model_path: a string specifying the path to an existing SavedModel.
+
+ Returns:
+ a keras.Model instance.
+ """
+ # restore model topology from json string
+ model_json_filepath = os.path.join(
+ compat.as_bytes(saved_model_path),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ model_json = file_io.read_file_to_string(model_json_filepath)
+ model = model_from_json(model_json)
+
+ # restore model weights
+ checkpoint_prefix = os.path.join(
+ compat.as_text(saved_model_path),
+ compat.as_text(constants.VARIABLES_DIRECTORY),
+ compat.as_text(constants.VARIABLES_FILENAME))
+ model.load_weights(checkpoint_prefix)
+ return model
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
new file mode 100644
index 0000000000..107ae1b07b
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -0,0 +1,201 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Tests for saving/loading function for keras Model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import numpy as np
+
+from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
+from tensorflow.python import keras
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.platform import test
+from tensorflow.python.training import training as training_module
+
+
+class TestModelSavingandLoading(test.TestCase):
+
+ def test_saving_sequential_model(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_sequential_model_without_compile(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+
+ x = np.random.random((1, 3))
+ ref_y = model.predict(x)
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ def test_saving_functional_model(self):
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_functional_model_without_compile(self):
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_saving_with_tf_optimizer(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ ref_y = model.predict(x)
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ keras_saved_model.save_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ loaded_model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ # test that new updates are the same with both models
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+
+ ref_loss = model.train_on_batch(x, y)
+ loss = loaded_model.train_on_batch(x, y)
+ self.assertAllClose(ref_loss, loss, atol=1e-05)
+
+ ref_y = model.predict(x)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ # test saving/loading again
+ keras_saved_model.save_model(loaded_model, temp_saved_model)
+ loaded_model = keras_saved_model.load_model(temp_saved_model)
+ y = loaded_model.predict(x)
+ self.assertAllClose(ref_y, y, atol=1e-05)
+
+ def test_saving_subclassed_model_raise_error(self):
+ # For now, saving subclassed model should raise an error. It should be
+ # avoided later with loading from SavedModel.pb.
+
+ class SubclassedModel(training.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.layer1 = keras.layers.Dense(3)
+ self.layer2 = keras.layers.Dense(1)
+
+ def call(self, inp):
+ return self.layer2(self.layer1(inp))
+
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ temp_saved_model = os.path.join(temp_dir, 'saved_model')
+ with self.assertRaises(NotImplementedError):
+ keras_saved_model.save_model(model, temp_saved_model)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index 61c6ffbd0d..cb251f08bb 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -60,6 +60,10 @@ SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"
tf_export("saved_model.constants.SAVED_MODEL_FILENAME_PBTXT").export_constant(
__name__, "SAVED_MODEL_FILENAME_PBTXT")
+# File name for json format of SavedModel.
+# Not exported while keras_saved_model is in contrib.
+SAVED_MODEL_FILENAME_JSON = "saved_model.json"
+
# Subdirectory name containing the variables/checkpoint files.
VARIABLES_DIRECTORY = "variables"
tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
@@ -69,5 +73,3 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
VARIABLES_FILENAME = "variables"
tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
__name__, "VARIABLES_FILENAME")
-
-