aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-12-19 13:54:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-19 13:58:01 -0800
commit130d6c69cfd6a719cf4ccc31ae9b921c3c3cd56b (patch)
treea2961e7dad69f5dc87e1ea3b0d7b82a57b01af08 /tensorflow/contrib/saved_model
parentfcf61d57079c8874cd479d4b0dfdb48033e742d8 (diff)
Migrate SavedModel simple save functionality from contrib to
tensorflow/python/saved_model. PiperOrigin-RevId: 179599527
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r--tensorflow/contrib/saved_model/BUILD16
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/utils.py81
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/utils_test.py102
3 files changed, 0 insertions, 199 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 20be819e07..245fe07f2b 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -82,22 +82,6 @@ py_test(
],
)
-py_test(
- name = "utils_test",
- size = "small",
- srcs = ["python/saved_model/utils_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":saved_model_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:variables",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:tag_constants",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils.py b/tensorflow/contrib/saved_model/python/saved_model/utils.py
deleted file mode 100644
index 9f34af64a6..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/utils.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SavedModel utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.saved_model import builder
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.saved_model import tag_constants
-
-
-def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None):
- """Convenience function to build a SavedModel suitable for serving.
-
- In many common cases, saving models for serving will be as simple as:
-
- simple_save(session,
- export_dir,
- inputs={"x": x, "y": y},
- outputs={"z": z})
-
- Although in many cases it's not necessary to understand all of the many ways
- to configure a SavedModel, this method has a few practical implications:
- - It will be treated as a graph for inference / serving (i.e. uses the tag
- `tag_constants.SERVING`)
- - The saved model will load in TensorFlow Serving and supports the
- [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
- To use the Classify, Regress, or MultiInference APIs, please
- use either
- [tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
- or the lower level
- [SavedModel APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
- - Some TensorFlow ops depend on information on disk or other information
- called "assets". These are generally handled automatically by adding the
- assets to the `GraphKeys.ASSET_FILEPATHS` collection. Only assets in that
- collection are exported; if you need more custom behavior, you'll need to
- use the [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py).
-
- More information about SavedModel and signatures can be found here:
- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md.
-
- Args:
- session: The TensorFlow session from which to save the meta graph and
- variables.
- export_dir: The path to which the SavedModel will be stored.
- inputs: dict mapping string input names to tensors. These are added
- to the SignatureDef as the inputs.
- outputs: dict mapping string output names to tensors. These are added
- to the SignatureDef as the outputs.
- legacy_init_op: Legacy support for op or group of ops to execute after the
- restore op upon a load.
- """
- signature_def_map = {
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- signature_def_utils.predict_signature_def(inputs, outputs)
- }
- b = builder.SavedModelBuilder(export_dir)
- b.add_meta_graph_and_variables(
- session,
- tags=[tag_constants.SERVING],
- signature_def_map=signature_def_map,
- assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
- legacy_init_op=legacy_init_op,
- clear_devices=True)
- b.save()
diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/utils_test.py
deleted file mode 100644
index 36dfb88871..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for saved_model utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.saved_model.python.saved_model import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import loader
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-
-
-class UtilsTest(test.TestCase):
-
- def _init_and_validate_variable(self, sess, variable_name, variable_value):
- v = variables.Variable(variable_value, name=variable_name)
- sess.run(variables.global_variables_initializer())
- self.assertEqual(variable_value, v.eval())
- return v
-
- def _check_variable_info(self, actual_variable, expected_variable):
- self.assertEqual(actual_variable.name, expected_variable.name)
- self.assertEqual(actual_variable.dtype, expected_variable.dtype)
- self.assertEqual(len(actual_variable.shape), len(expected_variable.shape))
- for i in range(len(actual_variable.shape)):
- self.assertEqual(actual_variable.shape[i], expected_variable.shape[i])
-
- def _check_tensor_info(self, actual_tensor_info, expected_tensor):
- self.assertEqual(actual_tensor_info.name, expected_tensor.name)
- self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype)
- self.assertEqual(
- len(actual_tensor_info.tensor_shape.dim), len(expected_tensor.shape))
- for i in range(len(actual_tensor_info.tensor_shape.dim)):
- self.assertEqual(actual_tensor_info.tensor_shape.dim[i].size,
- expected_tensor.shape[i])
-
- def testSimpleSave(self):
- """Test simple_save that uses the default parameters."""
- export_dir = os.path.join(test.get_temp_dir(),
- "test_simple_save")
-
- # Initialize input and output variables and save a prediction graph using
- # the default parameters.
- with self.test_session(graph=ops.Graph()) as sess:
- var_x = self._init_and_validate_variable(sess, "var_x", 1)
- var_y = self._init_and_validate_variable(sess, "var_y", 2)
- inputs = {"x": var_x}
- outputs = {"y": var_y}
- utils.simple_save(sess, export_dir, inputs, outputs)
-
- # Restore the graph with a valid tag and check the global variables and
- # signature def map.
- with self.test_session(graph=ops.Graph()) as sess:
- graph = loader.load(sess, [tag_constants.SERVING], export_dir)
- collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-
- # Check value and metadata of the saved variables.
- self.assertEqual(len(collection_vars), 2)
- self.assertEqual(1, collection_vars[0].eval())
- self.assertEqual(2, collection_vars[1].eval())
- self._check_variable_info(collection_vars[0], var_x)
- self._check_variable_info(collection_vars[1], var_y)
-
- # Check that the appropriate signature_def_map is created with the
- # default key and method name, and the specified inputs and outputs.
- signature_def_map = graph.signature_def
- self.assertEqual(1, len(signature_def_map))
- self.assertEqual(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- list(signature_def_map.keys())[0])
-
- signature_def = signature_def_map[
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
- self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
- signature_def.method_name)
-
- self.assertEqual(1, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs["x"], var_x)
- self.assertEqual(1, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs["y"], var_y)
-
-
-if __name__ == "__main__":
- test.main()