aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-12-19 13:54:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-19 13:58:01 -0800
commit130d6c69cfd6a719cf4ccc31ae9b921c3c3cd56b (patch)
treea2961e7dad69f5dc87e1ea3b0d7b82a57b01af08 /tensorflow/python/saved_model
parentfcf61d57079c8874cd479d4b0dfdb48033e742d8 (diff)
Migrate SavedModel simple save functionality from contrib to
tensorflow/python/saved_model. PiperOrigin-RevId: 179599527
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/BUILD34
-rw-r--r--tensorflow/python/saved_model/saved_model.py4
-rw-r--r--tensorflow/python/saved_model/simple_save.py81
-rw-r--r--tensorflow/python/saved_model/simple_save_test.py102
4 files changed, 221 insertions, 0 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 39c6439811..e34aa7cc2c 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -25,6 +25,7 @@ py_library(
":main_op",
":signature_constants",
":signature_def_utils",
+ ":simple_save",
":tag_constants",
":utils",
"//tensorflow/python:util",
@@ -90,6 +91,23 @@ py_library(
)
py_library(
+ name = "simple_save",
+ srcs = [
+ "simple_save.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":builder",
+ ":signature_constants",
+ ":signature_def_utils",
+ ":tag_constants",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
name = "main_op",
srcs = [
"main_op.py",
@@ -198,6 +216,22 @@ py_test(
],
)
+py_test(
+ name = "simple_save_test",
+ size = "small",
+ srcs = ["simple_save_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":loader",
+ ":signature_constants",
+ ":simple_save",
+ ":tag_constants",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py
index 8c59f7afe7..caabd7bc30 100644
--- a/tensorflow/python/saved_model/saved_model.py
+++ b/tensorflow/python/saved_model/saved_model.py
@@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
# pylint: enable=unused-import
+# pylint: disable=wildcard-import
+from tensorflow.python.saved_model.simple_save import *
+# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,6 +44,7 @@ _allowed_symbols = [
"main_op",
"signature_constants",
"signature_def_utils",
+ "simple_save",
"tag_constants",
"utils",
]
diff --git a/tensorflow/python/saved_model/simple_save.py b/tensorflow/python/saved_model/simple_save.py
new file mode 100644
index 0000000000..9a81e5cd80
--- /dev/null
+++ b/tensorflow/python/saved_model/simple_save.py
@@ -0,0 +1,81 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel simple save functionality."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+
+
+def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None):
+ """Convenience function to build a SavedModel suitable for serving.
+
+ In many common cases, saving models for serving will be as simple as:
+
+ simple_save(session,
+ export_dir,
+ inputs={"x": x, "y": y},
+ outputs={"z": z})
+
+ Although in many cases it's not necessary to understand all of the many ways
+ to configure a SavedModel, this method has a few practical implications:
+ - It will be treated as a graph for inference / serving (i.e. uses the tag
+ `tag_constants.SERVING`)
+ - The SavedModel will load in TensorFlow Serving and supports the
+ [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
+ To use the Classify, Regress, or MultiInference APIs, please
+ use either
+ [tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
+ or the lower level
+ [SavedModel APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
+ - Some TensorFlow ops depend on information on disk or other information
+ called "assets". These are generally handled automatically by adding the
+ assets to the `GraphKeys.ASSET_FILEPATHS` collection. Only assets in that
+ collection are exported; if you need more custom behavior, you'll need to
+ use the [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py).
+
+ More information about SavedModel and signatures can be found here:
+ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md.
+
+ Args:
+ session: The TensorFlow session from which to save the meta graph and
+ variables.
+ export_dir: The path to which the SavedModel will be stored.
+ inputs: dict mapping string input names to tensors. These are added
+ to the SignatureDef as the inputs.
+ outputs: dict mapping string output names to tensors. These are added
+ to the SignatureDef as the outputs.
+ legacy_init_op: Legacy support for op or group of ops to execute after the
+ restore op upon a load.
+ """
+ signature_def_map = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature_def_utils.predict_signature_def(inputs, outputs)
+ }
+ b = builder.SavedModelBuilder(export_dir)
+ b.add_meta_graph_and_variables(
+ session,
+ tags=[tag_constants.SERVING],
+ signature_def_map=signature_def_map,
+ assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
+ legacy_init_op=legacy_init_op,
+ clear_devices=True)
+ b.save()
diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py
new file mode 100644
index 0000000000..b2fa40d4f1
--- /dev/null
+++ b/tensorflow/python/saved_model/simple_save_test.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModel simple save functionality."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import simple_save
+from tensorflow.python.saved_model import tag_constants
+
+
+class SimpleSaveTest(test.TestCase):
+
+ def _init_and_validate_variable(self, sess, variable_name, variable_value):
+ v = variables.Variable(variable_value, name=variable_name)
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(variable_value, v.eval())
+ return v
+
+ def _check_variable_info(self, actual_variable, expected_variable):
+ self.assertEqual(actual_variable.name, expected_variable.name)
+ self.assertEqual(actual_variable.dtype, expected_variable.dtype)
+ self.assertEqual(len(actual_variable.shape), len(expected_variable.shape))
+ for i in range(len(actual_variable.shape)):
+ self.assertEqual(actual_variable.shape[i], expected_variable.shape[i])
+
+ def _check_tensor_info(self, actual_tensor_info, expected_tensor):
+ self.assertEqual(actual_tensor_info.name, expected_tensor.name)
+ self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype)
+ self.assertEqual(
+ len(actual_tensor_info.tensor_shape.dim), len(expected_tensor.shape))
+ for i in range(len(actual_tensor_info.tensor_shape.dim)):
+ self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size,
+ expected_tensor.shape[i])
+
+ def testSimpleSave(self):
+ """Test simple_save that uses the default parameters."""
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_simple_save")
+
+ # Initialize input and output variables and save a prediction graph using
+ # the default parameters.
+ with self.test_session(graph=ops.Graph()) as sess:
+ var_x = self._init_and_validate_variable(sess, "var_x", 1)
+ var_y = self._init_and_validate_variable(sess, "var_y", 2)
+ inputs = {"x": var_x}
+ outputs = {"y": var_y}
+ simple_save.simple_save(sess, export_dir, inputs, outputs)
+
+ # Restore the graph with a valid tag and check the global variables and
+ # signature def map.
+ with self.test_session(graph=ops.Graph()) as sess:
+ graph = loader.load(sess, [tag_constants.SERVING], export_dir)
+ collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+
+ # Check value and metadata of the saved variables.
+ self.assertEqual(len(collection_vars), 2)
+ self.assertEqual(1, collection_vars[0].eval())
+ self.assertEqual(2, collection_vars[1].eval())
+ self._check_variable_info(collection_vars[0], var_x)
+ self._check_variable_info(collection_vars[1], var_y)
+
+ # Check that the appropriate signature_def_map is created with the
+ # default key and method name, and the specified inputs and outputs.
+ signature_def_map = graph.signature_def
+ self.assertEqual(1, len(signature_def_map))
+ self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ list(signature_def_map.keys())[0])
+
+ signature_def = signature_def_map[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+ self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
+ signature_def.method_name)
+
+ self.assertEqual(1, len(signature_def.inputs))
+ self._check_tensor_info(signature_def.inputs["x"], var_x)
+ self.assertEqual(1, len(signature_def.outputs))
+ self._check_tensor_info(signature_def.outputs["y"], var_y)
+
+
+if __name__ == "__main__":
+ test.main()