aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-03-09 18:59:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-09 19:19:47 -0800
commitdd64c3223cbe587e0e3a820aaaf539e104bee7cb (patch)
tree3c7c236de27c05dc23b59c8aba306687d74b694c /tensorflow/contrib/saved_model
parentc559153f2bad97f58f63ed17150ee97d4f0f0dd0 (diff)
Add utils to get signature def from meta graph def.
Change: 149720034
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r--tensorflow/contrib/saved_model/BUILD69
-rw-r--r--tensorflow/contrib/saved_model/__init__.py34
-rw-r--r--tensorflow/contrib/saved_model/python/__init__.py28
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/__init__.py28
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py42
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py191
6 files changed, 392 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
new file mode 100644
index 0000000000..42514c8379
--- /dev/null
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -0,0 +1,69 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Description:
+# SavedModel contrib libraries.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "saved_model_py",
+ srcs = [
+ "__init__.py",
+ "python/__init__.py",
+ ] + glob(
+ ["python/saved_model/*.py"],
+ exclude = ["python/saved_model/*_test.py"],
+ ),
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "signature_def_utils_test",
+ size = "small",
+ srcs = ["python/saved_model/signature_def_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":saved_model_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:utils",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
new file mode 100644
index 0000000000..b4f27a055d
--- /dev/null
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel contrib support.
+
+SavedModel provides a language-neutral format to save machine-learned models
+that is recoverable and hermetic. It enables higher-level systems and tools to
+produce, consume and transform TensorFlow models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
+# pylint: enable=unused-import,widcard-import,line-too-long
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ["get_signature_def_by_key"]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/python/__init__.py b/tensorflow/contrib/saved_model/python/__init__.py
new file mode 100644
index 0000000000..f186c520c5
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel contrib support.
+
+SavedModel provides a language-neutral format to save machine-learned models
+that is recoverable and hermetic. It enables higher-level systems and tools to
+produce, consume and transform TensorFlow models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.saved_model.python.saved_model import *
+# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
new file mode 100644
index 0000000000..7b91622b61
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel contrib support.
+
+SavedModel provides a language-neutral format to save machine-learned models
+that is recoverable and hermetic. It enables higher-level systems and tools to
+produce, consume and transform TensorFlow models.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
+# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
new file mode 100644
index 0000000000..f521647999
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
@@ -0,0 +1,42 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SignatureDef utility functions implementation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def get_signature_def_by_key(meta_graph_def, signature_def_key):
+ """Utility function to get a SignatureDef protocol buffer by its key.
+
+ Args:
+ meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefMap to
+ look up.
+ signature_def_key: Key of the SignatureDef protocol buffer to find in the
+ SignatureDefMap.
+
+ Returns:
+ A SignatureDef protocol buffer corresponding to the supplied key, if it
+ exists.
+
+ Raises:
+ ValueError: If no entry corresponding to the supplied key is found in the
+ SignatureDefMap of the MetaGraphDef.
+ """
+ if signature_def_key not in meta_graph_def.signature_def:
+ raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." %
+ signature_def_key)
+ return meta_graph_def.signature_def[signature_def_key]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
new file mode 100644
index 0000000000..282dd7dc3b
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
@@ -0,0 +1,191 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SignatureDef utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils as signature_def_contrib_utils
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import utils
+
+
+class SignatureDefUtilsTest(test.TestCase):
+
+ def _add_to_signature_def_map(self, meta_graph_def, signature_def_map=None):
+ if signature_def_map is not None:
+ for key in signature_def_map:
+ meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])
+
+ def _check_tensor_info(self, tensor_info_map, map_key, expected_tensor_name):
+ actual_tensor_info = tensor_info_map[map_key]
+ self.assertEqual(expected_tensor_name, actual_tensor_info.name)
+
+ def testGetSignatureDefByKey(self):
+ x = array_ops.placeholder(dtypes.float32, 1, name="x")
+ x_tensor_info = utils.build_tensor_info(x)
+
+ y = array_ops.placeholder(dtypes.float32, name="y")
+ y_tensor_info = utils.build_tensor_info(y)
+
+ foo_signature_def = signature_def_utils.build_signature_def({
+ "foo-input": x_tensor_info
+ }, {"foo-output": y_tensor_info}, "foo-method-name")
+ bar_signature_def = signature_def_utils.build_signature_def({
+ "bar-input": x_tensor_info
+ }, {"bar-output": y_tensor_info}, "bar-method-name")
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ self._add_to_signature_def_map(
+ meta_graph_def, {"foo": foo_signature_def,
+ "bar": bar_signature_def})
+
+ # Look up a key that does not exist in the SignatureDefMap.
+ missing_key = "missing-key"
+ with self.assertRaisesRegexp(
+ ValueError,
+ "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
+ signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, missing_key)
+
+ # Look up the key, `foo` which exists in the SignatureDefMap.
+ foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, "foo")
+ self.assertTrue("foo-method-name", foo_signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(foo_signature_def.inputs))
+ self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(foo_signature_def.outputs))
+ self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")
+
+ # Look up the key, `bar` which exists in the SignatureDefMap.
+ bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, "bar")
+ self.assertTrue("bar-method-name", bar_signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(bar_signature_def.inputs))
+ self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(bar_signature_def.outputs))
+ self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
+
+ def testGetSignatureDefByKeyRegression(self):
+ input1 = constant_op.constant("a", name="input-1")
+ output1 = constant_op.constant("b", name="output-1")
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ self._add_to_signature_def_map(meta_graph_def, {
+ "my_regression":
+ signature_def_utils.regression_signature_def(input1, output1)
+ })
+
+ # Look up the regression signature with the key used while saving.
+ signature_def = signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, "my_regression")
+
+ # Check the method name to match the constants regression method name.
+ self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ self._check_tensor_info(signature_def.inputs,
+ signature_constants.REGRESS_INPUTS, "input-1:0")
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ self._check_tensor_info(signature_def.outputs,
+ signature_constants.REGRESS_OUTPUTS, "output-1:0")
+
+ def testGetSignatureDefByKeyClassification(self):
+ input1 = constant_op.constant("a", name="input-1")
+ output1 = constant_op.constant("b", name="output-1")
+ output2 = constant_op.constant("c", name="output-2")
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ self._add_to_signature_def_map(meta_graph_def, {
+ "my_classification":
+ signature_def_utils.classification_signature_def(
+ input1, output1, output2)
+ })
+
+ # Look up the classification signature def with the key used while saving.
+ signature_def = signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, "my_classification")
+
+ # Check the method name to match the constants classification method name.
+ self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ self._check_tensor_info(signature_def.inputs,
+ signature_constants.CLASSIFY_INPUTS, "input-1:0")
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ self._check_tensor_info(signature_def.outputs,
+ signature_constants.CLASSIFY_OUTPUT_CLASSES,
+ "output-1:0")
+ self._check_tensor_info(signature_def.outputs,
+ signature_constants.CLASSIFY_OUTPUT_SCORES,
+ "output-2:0")
+
+ def testPredictionSignatureDef(self):
+ input1 = constant_op.constant("a", name="input-1")
+ input2 = constant_op.constant("b", name="input-2")
+ output1 = constant_op.constant("c", name="output-1")
+ output2 = constant_op.constant("d", name="output-2")
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ self._add_to_signature_def_map(meta_graph_def, {
+ "my_prediction":
+ signature_def_utils.predict_signature_def({
+ "input-1": input1,
+ "input-2": input2
+ }, {"output-1": output1,
+ "output-2": output2})
+ })
+
+ # Look up the prediction signature def with the key used while saving.
+ signature_def = signature_def_contrib_utils.get_signature_def_by_key(
+ meta_graph_def, "my_prediction")
+ self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(2, len(signature_def.inputs))
+ self._check_tensor_info(signature_def.inputs, "input-1", "input-1:0")
+ self._check_tensor_info(signature_def.inputs, "input-2", "input-2:0")
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ self._check_tensor_info(signature_def.outputs, "output-1", "output-1:0")
+ self._check_tensor_info(signature_def.outputs, "output-2", "output-2:0")
+
+
+if __name__ == "__main__":
+ test.main()