diff options
author | Sukriti Ramesh <sukritiramesh@google.com> | 2017-12-19 13:54:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-19 13:58:01 -0800 |
commit | 130d6c69cfd6a719cf4ccc31ae9b921c3c3cd56b (patch) | |
tree | a2961e7dad69f5dc87e1ea3b0d7b82a57b01af08 /tensorflow/python/saved_model | |
parent | fcf61d57079c8874cd479d4b0dfdb48033e742d8 (diff) |
Migrate SavedModel simple save functionality from contrib to
tensorflow/python/saved_model.
PiperOrigin-RevId: 179599527
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 34 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model.py | 4 | ||||
-rw-r--r-- | tensorflow/python/saved_model/simple_save.py | 81 | ||||
-rw-r--r-- | tensorflow/python/saved_model/simple_save_test.py | 102 |
4 files changed, 221 insertions, 0 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 39c6439811..e34aa7cc2c 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -25,6 +25,7 @@ py_library( ":main_op", ":signature_constants", ":signature_def_utils", + ":simple_save", ":tag_constants", ":utils", "//tensorflow/python:util", @@ -90,6 +91,23 @@ py_library( ) py_library( + name = "simple_save", + srcs = [ + "simple_save.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":builder", + ":signature_constants", + ":signature_def_utils", + ":tag_constants", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", + "//tensorflow/python:util", + ], +) + +py_library( name = "main_op", srcs = [ "main_op.py", @@ -198,6 +216,22 @@ py_test( ], ) +py_test( + name = "simple_save_test", + size = "small", + srcs = ["simple_save_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":loader", + ":signature_constants", + ":simple_save", + ":tag_constants", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:variables", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py index 8c59f7afe7..caabd7bc30 100644 --- a/tensorflow/python/saved_model/saved_model.py +++ b/tensorflow/python/saved_model/saved_model.py @@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils # pylint: enable=unused-import +# pylint: disable=wildcard-import +from tensorflow.python.saved_model.simple_save import * +# pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -41,6 +44,7 @@ _allowed_symbols = [ "main_op", "signature_constants", "signature_def_utils", + "simple_save", "tag_constants", "utils", ] diff --git a/tensorflow/python/saved_model/simple_save.py b/tensorflow/python/saved_model/simple_save.py new file mode 100644 index 0000000000..9a81e5cd80 --- /dev/null +++ b/tensorflow/python/saved_model/simple_save.py @@ -0,0 +1,81 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel simple save functionality.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import builder +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import tag_constants + + +def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): + """Convenience function to build a SavedModel suitable for serving. + + In many common cases, saving models for serving will be as simple as: + + simple_save(session, + export_dir, + inputs={"x": x, "y": y}, + outputs={"z": z}) + + Although in many cases it's not necessary to understand all of the many ways + to configure a SavedModel, this method has a few practical implications: + - It will be treated as a graph for inference / serving (i.e. uses the tag + `tag_constants.SERVING`) + - The SavedModel will load in TensorFlow Serving and supports the + [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). + To use the Classify, Regress, or MultiInference APIs, please + use either + [tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) + or the lower level + [SavedModel APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). + - Some TensorFlow ops depend on information on disk or other information + called "assets". These are generally handled automatically by adding the + assets to the `GraphKeys.ASSET_FILEPATHS` collection. Only assets in that + collection are exported; if you need more custom behavior, you'll need to + use the [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py). + + More information about SavedModel and signatures can be found here: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md. + + Args: + session: The TensorFlow session from which to save the meta graph and + variables. + export_dir: The path to which the SavedModel will be stored. + inputs: dict mapping string input names to tensors. These are added + to the SignatureDef as the inputs. + outputs: dict mapping string output names to tensors. These are added + to the SignatureDef as the outputs. + legacy_init_op: Legacy support for op or group of ops to execute after the + restore op upon a load. + """ + signature_def_map = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + signature_def_utils.predict_signature_def(inputs, outputs) + } + b = builder.SavedModelBuilder(export_dir) + b.add_meta_graph_and_variables( + session, + tags=[tag_constants.SERVING], + signature_def_map=signature_def_map, + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), + legacy_init_op=legacy_init_op, + clear_devices=True) + b.save() diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py new file mode 100644 index 0000000000..b2fa40d4f1 --- /dev/null +++ b/tensorflow/python/saved_model/simple_save_test.py @@ -0,0 +1,102 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModel simple save functionality.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import simple_save +from tensorflow.python.saved_model import tag_constants + + +class SimpleSaveTest(test.TestCase): + + def _init_and_validate_variable(self, sess, variable_name, variable_value): + v = variables.Variable(variable_value, name=variable_name) + sess.run(variables.global_variables_initializer()) + self.assertEqual(variable_value, v.eval()) + return v + + def _check_variable_info(self, actual_variable, expected_variable): + self.assertEqual(actual_variable.name, expected_variable.name) + self.assertEqual(actual_variable.dtype, expected_variable.dtype) + self.assertEqual(len(actual_variable.shape), len(expected_variable.shape)) + for i in range(len(actual_variable.shape)): + self.assertEqual(actual_variable.shape[i], expected_variable.shape[i]) + + def _check_tensor_info(self, actual_tensor_info, expected_tensor): + self.assertEqual(actual_tensor_info.name, expected_tensor.name) + self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype) + self.assertEqual( + len(actual_tensor_info.tensor_shape.dim), len(expected_tensor.shape)) + for i in range(len(actual_tensor_info.tensor_shape.dim)): + self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size, + expected_tensor.shape[i]) + + def testSimpleSave(self): + """Test simple_save that uses the default parameters.""" + export_dir = os.path.join(test.get_temp_dir(), + "test_simple_save") + + # Initialize input and output variables and save a prediction graph using + # the default parameters. + with self.test_session(graph=ops.Graph()) as sess: + var_x = self._init_and_validate_variable(sess, "var_x", 1) + var_y = self._init_and_validate_variable(sess, "var_y", 2) + inputs = {"x": var_x} + outputs = {"y": var_y} + simple_save.simple_save(sess, export_dir, inputs, outputs) + + # Restore the graph with a valid tag and check the global variables and + # signature def map. + with self.test_session(graph=ops.Graph()) as sess: + graph = loader.load(sess, [tag_constants.SERVING], export_dir) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + + # Check value and metadata of the saved variables. + self.assertEqual(len(collection_vars), 2) + self.assertEqual(1, collection_vars[0].eval()) + self.assertEqual(2, collection_vars[1].eval()) + self._check_variable_info(collection_vars[0], var_x) + self._check_variable_info(collection_vars[1], var_y) + + # Check that the appropriate signature_def_map is created with the + # default key and method name, and the specified inputs and outputs. + signature_def_map = graph.signature_def + self.assertEqual(1, len(signature_def_map)) + self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + list(signature_def_map.keys())[0]) + + signature_def = signature_def_map[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertEqual(signature_constants.PREDICT_METHOD_NAME, + signature_def.method_name) + + self.assertEqual(1, len(signature_def.inputs)) + self._check_tensor_info(signature_def.inputs["x"], var_x) + self.assertEqual(1, len(signature_def.outputs)) + self._check_tensor_info(signature_def.outputs["y"], var_y) + + +if __name__ == "__main__": + test.main() |