diff options
author | Sukriti Ramesh <sukritiramesh@google.com> | 2017-03-09 18:59:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-09 19:19:47 -0800 |
commit | dd64c3223cbe587e0e3a820aaaf539e104bee7cb (patch) | |
tree | 3c7c236de27c05dc23b59c8aba306687d74b694c /tensorflow/contrib/saved_model | |
parent | c559153f2bad97f58f63ed17150ee97d4f0f0dd0 (diff) |
Add utils to get signature def from meta graph def.
Change: 149720034
Diffstat (limited to 'tensorflow/contrib/saved_model')
6 files changed, 392 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD new file mode 100644 index 0000000000..42514c8379 --- /dev/null +++ b/tensorflow/contrib/saved_model/BUILD @@ -0,0 +1,69 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Description: +# SavedModel contrib libraries. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "saved_model_py", + srcs = [ + "__init__.py", + "python/__init__.py", + ] + glob( + ["python/saved_model/*.py"], + exclude = ["python/saved_model/*_test.py"], + ), + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_py", + ], +) + +py_test( + name = "signature_def_utils_test", + size = "small", + srcs = ["python/saved_model/signature_def_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":saved_model_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow/python/saved_model:utils", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py new file mode 100644 index 0000000000..b4f27a055d --- /dev/null +++ b/tensorflow/contrib/saved_model/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel contrib support. + +SavedModel provides a language-neutral format to save machine-learned models +that is recoverable and hermetic. It enables higher-level systems and tools to +produce, consume and transform TensorFlow models. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import * +# pylint: enable=unused-import,widcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["get_signature_def_by_key"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/saved_model/python/__init__.py b/tensorflow/contrib/saved_model/python/__init__.py new file mode 100644 index 0000000000..f186c520c5 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel contrib support. + +SavedModel provides a language-neutral format to save machine-learned models +that is recoverable and hermetic. It enables higher-level systems and tools to +produce, consume and transform TensorFlow models. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.saved_model.python.saved_model import * +# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py new file mode 100644 index 0000000000..7b91622b61 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel contrib support. + +SavedModel provides a language-neutral format to save machine-learned models +that is recoverable and hermetic. It enables higher-level systems and tools to +produce, consume and transform TensorFlow models. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py new file mode 100644 index 0000000000..f521647999 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py @@ -0,0 +1,42 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SignatureDef utility functions implementation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def get_signature_def_by_key(meta_graph_def, signature_def_key): + """Utility function to get a SignatureDef protocol buffer by its key. + + Args: + meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefMap to + look up. + signature_def_key: Key of the SignatureDef protocol buffer to find in the + SignatureDefMap. + + Returns: + A SignatureDef protocol buffer corresponding to the supplied key, if it + exists. + + Raises: + ValueError: If no entry corresponding to the supplied key is found in the + SignatureDefMap of the MetaGraphDef. + """ + if signature_def_key not in meta_graph_def.signature_def: + raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." % + signature_def_key) + return meta_graph_def.signature_def[signature_def_key] diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py new file mode 100644 index 0000000000..282dd7dc3b --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py @@ -0,0 +1,191 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SignatureDef utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils as signature_def_contrib_utils +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils + + +class SignatureDefUtilsTest(test.TestCase): + + def _add_to_signature_def_map(self, meta_graph_def, signature_def_map=None): + if signature_def_map is not None: + for key in signature_def_map: + meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key]) + + def _check_tensor_info(self, tensor_info_map, map_key, expected_tensor_name): + actual_tensor_info = tensor_info_map[map_key] + self.assertEqual(expected_tensor_name, actual_tensor_info.name) + + def testGetSignatureDefByKey(self): + x = array_ops.placeholder(dtypes.float32, 1, name="x") + x_tensor_info = utils.build_tensor_info(x) + + y = array_ops.placeholder(dtypes.float32, name="y") + y_tensor_info = utils.build_tensor_info(y) + + foo_signature_def = signature_def_utils.build_signature_def({ + "foo-input": x_tensor_info + }, {"foo-output": y_tensor_info}, "foo-method-name") + bar_signature_def = signature_def_utils.build_signature_def({ + "bar-input": x_tensor_info + }, {"bar-output": y_tensor_info}, "bar-method-name") + meta_graph_def = meta_graph_pb2.MetaGraphDef() + self._add_to_signature_def_map( + meta_graph_def, {"foo": foo_signature_def, + "bar": bar_signature_def}) + + # Look up a key that does not exist in the SignatureDefMap. + missing_key = "missing-key" + with self.assertRaisesRegexp( + ValueError, + "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key): + signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, missing_key) + + # Look up the key, `foo` which exists in the SignatureDefMap. + foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, "foo") + self.assertTrue("foo-method-name", foo_signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(foo_signature_def.inputs)) + self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0") + + # Check outputs in signature def. + self.assertEqual(1, len(foo_signature_def.outputs)) + self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0") + + # Look up the key, `bar` which exists in the SignatureDefMap. + bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, "bar") + self.assertTrue("bar-method-name", bar_signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(bar_signature_def.inputs)) + self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0") + + # Check outputs in signature def. + self.assertEqual(1, len(bar_signature_def.outputs)) + self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0") + + def testGetSignatureDefByKeyRegression(self): + input1 = constant_op.constant("a", name="input-1") + output1 = constant_op.constant("b", name="output-1") + + meta_graph_def = meta_graph_pb2.MetaGraphDef() + self._add_to_signature_def_map(meta_graph_def, { + "my_regression": + signature_def_utils.regression_signature_def(input1, output1) + }) + + # Look up the regression signature with the key used while saving. + signature_def = signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, "my_regression") + + # Check the method name to match the constants regression method name. + self.assertEqual(signature_constants.REGRESS_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + self._check_tensor_info(signature_def.inputs, + signature_constants.REGRESS_INPUTS, "input-1:0") + + # Check outputs in signature def. + self.assertEqual(1, len(signature_def.outputs)) + self._check_tensor_info(signature_def.outputs, + signature_constants.REGRESS_OUTPUTS, "output-1:0") + + def testGetSignatureDefByKeyClassification(self): + input1 = constant_op.constant("a", name="input-1") + output1 = constant_op.constant("b", name="output-1") + output2 = constant_op.constant("c", name="output-2") + + meta_graph_def = meta_graph_pb2.MetaGraphDef() + self._add_to_signature_def_map(meta_graph_def, { + "my_classification": + signature_def_utils.classification_signature_def( + input1, output1, output2) + }) + + # Look up the classification signature def with the key used while saving. + signature_def = signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, "my_classification") + + # Check the method name to match the constants classification method name. + self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + self._check_tensor_info(signature_def.inputs, + signature_constants.CLASSIFY_INPUTS, "input-1:0") + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + self._check_tensor_info(signature_def.outputs, + signature_constants.CLASSIFY_OUTPUT_CLASSES, + "output-1:0") + self._check_tensor_info(signature_def.outputs, + signature_constants.CLASSIFY_OUTPUT_SCORES, + "output-2:0") + + def testPredictionSignatureDef(self): + input1 = constant_op.constant("a", name="input-1") + input2 = constant_op.constant("b", name="input-2") + output1 = constant_op.constant("c", name="output-1") + output2 = constant_op.constant("d", name="output-2") + + meta_graph_def = meta_graph_pb2.MetaGraphDef() + self._add_to_signature_def_map(meta_graph_def, { + "my_prediction": + signature_def_utils.predict_signature_def({ + "input-1": input1, + "input-2": input2 + }, {"output-1": output1, + "output-2": output2}) + }) + + # Look up the prediction signature def with the key used while saving. + signature_def = signature_def_contrib_utils.get_signature_def_by_key( + meta_graph_def, "my_prediction") + self.assertEqual(signature_constants.PREDICT_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(2, len(signature_def.inputs)) + self._check_tensor_info(signature_def.inputs, "input-1", "input-1:0") + self._check_tensor_info(signature_def.inputs, "input-2", "input-2:0") + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + self._check_tensor_info(signature_def.outputs, "output-1", "output-1:0") + self._check_tensor_info(signature_def.outputs, "output-2", "output-2:0") + + +if __name__ == "__main__": + test.main() |