From 130d6c69cfd6a719cf4ccc31ae9b921c3c3cd56b Mon Sep 17 00:00:00 2001 From: Sukriti Ramesh Date: Tue, 19 Dec 2017 13:54:53 -0800 Subject: Migrate SavedModel simple save functionality from contrib to tensorflow/python/saved_model. PiperOrigin-RevId: 179599527 --- tensorflow/contrib/saved_model/BUILD | 16 ---- .../saved_model/python/saved_model/utils.py | 81 ---------------- .../saved_model/python/saved_model/utils_test.py | 102 --------------------- 3 files changed, 199 deletions(-) delete mode 100644 tensorflow/contrib/saved_model/python/saved_model/utils.py delete mode 100644 tensorflow/contrib/saved_model/python/saved_model/utils_test.py (limited to 'tensorflow/contrib/saved_model') diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 20be819e07..245fe07f2b 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,22 +82,6 @@ py_test( ], ) -py_test( - name = "utils_test", - size = "small", - srcs = ["python/saved_model/utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":saved_model_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:variables", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:tag_constants", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils.py b/tensorflow/contrib/saved_model/python/saved_model/utils.py deleted file mode 100644 index 9f34af64a6..0000000000 --- a/tensorflow/contrib/saved_model/python/saved_model/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SavedModel utility functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.saved_model import builder -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.saved_model import tag_constants - - -def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): - """Convenience function to build a SavedModel suitable for serving. - - In many common cases, saving models for serving will be as simple as: - - simple_save(session, - export_dir, - inputs={"x": x, "y": y}, - outputs={"z": z}) - - Although in many cases it's not necessary to understand all of the many ways - to configure a SavedModel, this method has a few practical implications: - - It will be treated as a graph for inference / serving (i.e. uses the tag - `tag_constants.SERVING`) - - The saved model will load in TensorFlow Serving and supports the - [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). - To use the Classify, Regress, or MultiInference APIs, please - use either - [tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) - or the lower level - [SavedModel APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). - - Some TensorFlow ops depend on information on disk or other information - called "assets". These are generally handled automatically by adding the - assets to the `GraphKeys.ASSET_FILEPATHS` collection. Only assets in that - collection are exported; if you need more custom behavior, you'll need to - use the [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py). - - More information about SavedModel and signatures can be found here: - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md. - - Args: - session: The TensorFlow session from which to save the meta graph and - variables. - export_dir: The path to which the SavedModel will be stored. - inputs: dict mapping string input names to tensors. These are added - to the SignatureDef as the inputs. - outputs: dict mapping string output names to tensors. These are added - to the SignatureDef as the outputs. - legacy_init_op: Legacy support for op or group of ops to execute after the - restore op upon a load. - """ - signature_def_map = { - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - signature_def_utils.predict_signature_def(inputs, outputs) - } - b = builder.SavedModelBuilder(export_dir) - b.add_meta_graph_and_variables( - session, - tags=[tag_constants.SERVING], - signature_def_map=signature_def_map, - assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=legacy_init_op, - clear_devices=True) - b.save() diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/utils_test.py deleted file mode 100644 index 36dfb88871..0000000000 --- a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for saved_model utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.contrib.saved_model.python.saved_model import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -class UtilsTest(test.TestCase): - - def _init_and_validate_variable(self, sess, variable_name, variable_value): - v = variables.Variable(variable_value, name=variable_name) - sess.run(variables.global_variables_initializer()) - self.assertEqual(variable_value, v.eval()) - return v - - def _check_variable_info(self, actual_variable, expected_variable): - self.assertEqual(actual_variable.name, expected_variable.name) - self.assertEqual(actual_variable.dtype, expected_variable.dtype) - self.assertEqual(len(actual_variable.shape), len(expected_variable.shape)) - for i in range(len(actual_variable.shape)): - self.assertEqual(actual_variable.shape[i], expected_variable.shape[i]) - - def _check_tensor_info(self, actual_tensor_info, expected_tensor): - self.assertEqual(actual_tensor_info.name, expected_tensor.name) - self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype) - self.assertEqual( - len(actual_tensor_info.tensor_shape.dim), len(expected_tensor.shape)) - for i in range(len(actual_tensor_info.tensor_shape.dim)): - self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size, - expected_tensor.shape[i]) - - def testSimpleSave(self): - """Test simple_save that uses the default parameters.""" - export_dir = os.path.join(test.get_temp_dir(), - "test_simple_save") - - # Initialize input and output variables and save a prediction graph using - # the default parameters. - with self.test_session(graph=ops.Graph()) as sess: - var_x = self._init_and_validate_variable(sess, "var_x", 1) - var_y = self._init_and_validate_variable(sess, "var_y", 2) - inputs = {"x": var_x} - outputs = {"y": var_y} - utils.simple_save(sess, export_dir, inputs, outputs) - - # Restore the graph with a valid tag and check the global variables and - # signature def map. - with self.test_session(graph=ops.Graph()) as sess: - graph = loader.load(sess, [tag_constants.SERVING], export_dir) - collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - - # Check value and metadata of the saved variables. - self.assertEqual(len(collection_vars), 2) - self.assertEqual(1, collection_vars[0].eval()) - self.assertEqual(2, collection_vars[1].eval()) - self._check_variable_info(collection_vars[0], var_x) - self._check_variable_info(collection_vars[1], var_y) - - # Check that the appropriate signature_def_map is created with the - # default key and method name, and the specified inputs and outputs. - signature_def_map = graph.signature_def - self.assertEqual(1, len(signature_def_map)) - self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - list(signature_def_map.keys())[0]) - - signature_def = signature_def_map[ - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - self.assertEqual(signature_constants.PREDICT_METHOD_NAME, - signature_def.method_name) - - self.assertEqual(1, len(signature_def.inputs)) - self._check_tensor_info(signature_def.inputs["x"], var_x) - self.assertEqual(1, len(signature_def.outputs)) - self._check_tensor_info(signature_def.outputs["y"], var_y) - - -if __name__ == "__main__": - test.main() -- cgit v1.2.3