diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-15 04:53:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-15 04:56:56 -0700 |
commit | c1e30ab82641c0c0645a65f8046761d1710617a5 (patch) | |
tree | 1383100cdf61f53e6bb4faeed97f361a1fdea897 /tensorflow/python/saved_model | |
parent | 405b657c07b1754d831a0442eec0d9c5df793042 (diff) |
Add signature def utility functions for inspection of input and output types and shapes.
PiperOrigin-RevId: 168820997
Diffstat (limited to 'tensorflow/python/saved_model')
3 files changed, 143 insertions, 19 deletions
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py index be29a0f6b1..a7c648ce2f 100644 --- a/tensorflow/python/saved_model/signature_def_utils.py +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -14,7 +14,7 @@ # ============================================================================== """SignatureDef utility functions. -Utility functions for constructing SignatureDef protos. +Utility functions for building and inspecting SignatureDef protos. """ from __future__ import absolute_import from __future__ import division @@ -26,13 +26,7 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def # pylint: enable=unused-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - "build_signature_def", - "classification_signature_def", - "predict_signature_def", - "regression_signature_def", -] -remove_undocumented(__name__, _allowed_symbols) +del absolute_import +del division +del print_function diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index 0559fb415e..7a3fb16825 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils @@ -146,3 +148,81 @@ def predict_signature_def(inputs, outputs): signature_constants.PREDICT_METHOD_NAME) return signature_def + + +def _get_shapes_from_tensor_info_dict(tensor_info_dict): + """Returns a map of keys to TensorShape objects. + + Args: + tensor_info_dict: map with TensorInfo proto as values. + + Returns: + Map with corresponding TensorShape objects as values. + """ + return { + key: tensor_shape.TensorShape(tensor_info.tensor_shape) + for key, tensor_info in tensor_info_dict.items() + } + + +def _get_types_from_tensor_info_dict(tensor_info_dict): + """Returns a map of keys to DType objects. + + Args: + tensor_info_dict: map with TensorInfo proto as values. + + Returns: + Map with corresponding DType objects as values. + """ + return { + key: dtypes.DType(tensor_info.dtype) + for key, tensor_info in tensor_info_dict.items() + } + + +def get_signature_def_input_shapes(signature): + """Returns map of parameter names to their shapes. + + Args: + signature: SignatureDef proto. + + Returns: + Map from string to TensorShape objects. + """ + return _get_shapes_from_tensor_info_dict(signature.inputs) + + +def get_signature_def_input_types(signature): + """Returns map of output names to their types. + + Args: + signature: SignatureDef proto. + + Returns: + Map from string to DType objects. + """ + return _get_types_from_tensor_info_dict(signature.inputs) + + +def get_signature_def_output_shapes(signature): + """Returns map of output names to their shapes. + + Args: + signature: SignatureDef proto. + + Returns: + Map from string to TensorShape objects. + """ + return _get_shapes_from_tensor_info_dict(signature.outputs) + + +def get_signature_def_output_types(signature): + """Returns map of output names to their types. + + Args: + signature: SignatureDef proto. + + Returns: + Map from string to DType objects. + """ + return _get_types_from_tensor_info_dict(signature.outputs) diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index 5859496cf3..6627602849 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -24,10 +24,23 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import signature_def_utils_impl from tensorflow.python.saved_model import utils +def _make_signature(inputs, outputs, name=None): + input_info = { + input_name: utils.build_tensor_info(tensor) + for input_name, tensor in inputs.items() + } + output_info = { + output_name: utils.build_tensor_info(tensor) + for output_name, tensor in outputs.items() + } + return signature_def_utils_impl.build_signature_def(input_info, output_info, + name) + + class SignatureDefUtilsTest(test.TestCase): def testBuildSignatureDef(self): @@ -41,8 +54,8 @@ class SignatureDefUtilsTest(test.TestCase): outputs = dict() outputs["foo-output"] = y_tensor_info - signature_def = signature_def_utils.build_signature_def(inputs, outputs, - "foo-method-name") + signature_def = signature_def_utils_impl.build_signature_def( + inputs, outputs, "foo-method-name") self.assertEqual("foo-method-name", signature_def.method_name) # Check inputs in signature def. @@ -63,8 +76,8 @@ class SignatureDefUtilsTest(test.TestCase): def testRegressionSignatureDef(self): input1 = constant_op.constant("a", name="input-1") output1 = constant_op.constant("b", name="output-1") - signature_def = signature_def_utils.regression_signature_def(input1, - output1) + signature_def = signature_def_utils_impl.regression_signature_def( + input1, output1) self.assertEqual(signature_constants.REGRESS_METHOD_NAME, signature_def.method_name) @@ -89,9 +102,8 @@ class SignatureDefUtilsTest(test.TestCase): input1 = constant_op.constant("a", name="input-1") output1 = constant_op.constant("b", name="output-1") output2 = constant_op.constant("c", name="output-2") - signature_def = signature_def_utils.classification_signature_def(input1, - output1, - output2) + signature_def = signature_def_utils_impl.classification_signature_def( + input1, output1, output2) self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME, signature_def.method_name) @@ -122,7 +134,7 @@ class SignatureDefUtilsTest(test.TestCase): input2 = constant_op.constant("b", name="input-2") output1 = constant_op.constant("c", name="output-1") output2 = constant_op.constant("d", name="output-2") - signature_def = signature_def_utils.predict_signature_def({ + signature_def = signature_def_utils_impl.predict_signature_def({ "input-1": input1, "input-2": input2 }, {"output-1": output1, @@ -153,6 +165,44 @@ class SignatureDefUtilsTest(test.TestCase): self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + def testGetShapeAndTypes(self): + inputs = { + "input-1": constant_op.constant(["a", "b"]), + "input-2": array_ops.placeholder(dtypes.float32, [10, 11]), + } + outputs = { + "output-1": array_ops.placeholder(dtypes.float32, [10, 32]), + "output-2": constant_op.constant([["b"]]), + } + signature_def = _make_signature(inputs, outputs) + self.assertEqual( + signature_def_utils_impl.get_signature_def_input_shapes(signature_def), + {"input-1": [2], "input-2": [10, 11]}) + self.assertEqual( + signature_def_utils_impl.get_signature_def_output_shapes(signature_def), + {"output-1": [10, 32], "output-2": [1, 1]}) + self.assertEqual( + signature_def_utils_impl.get_signature_def_input_types(signature_def), + {"input-1": dtypes.string, "input-2": dtypes.float32}) + self.assertEqual( + signature_def_utils_impl.get_signature_def_output_types(signature_def), + {"output-1": dtypes.float32, "output-2": dtypes.string}) + + def testGetNonFullySpecifiedShapes(self): + outputs = { + "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]), + "output-2": array_ops.sparse_placeholder(dtypes.float32), + } + signature_def = _make_signature({}, outputs) + shapes = signature_def_utils_impl.get_signature_def_output_shapes( + signature_def) + self.assertEqual(len(shapes), 2) + # Must compare shapes with as_list() since 2 equivalent non-fully defined + # shapes are not equal to each other. + self.assertEqual(shapes["output-1"].as_list(), [None, 10, None]) + # Must compare `dims` since its an unknown shape. + self.assertEqual(shapes["output-2"].dims, None) + if __name__ == "__main__": test.main() |