aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-15 04:53:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-15 04:56:56 -0700
commitc1e30ab82641c0c0645a65f8046761d1710617a5 (patch)
tree1383100cdf61f53e6bb4faeed97f361a1fdea897 /tensorflow/python/saved_model
parent405b657c07b1754d831a0442eec0d9c5df793042 (diff)
Add signature def utility functions for inspection of input and output types and shapes.
PiperOrigin-RevId: 168820997
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py14
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py80
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py68
3 files changed, 143 insertions, 19 deletions
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index be29a0f6b1..a7c648ce2f 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""SignatureDef utility functions.
-Utility functions for constructing SignatureDef protos.
+Utility functions for building and inspecting SignatureDef protos.
"""
from __future__ import absolute_import
from __future__ import division
@@ -26,13 +26,7 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- "build_signature_def",
- "classification_signature_def",
- "predict_signature_def",
- "regression_signature_def",
-]
-remove_undocumented(__name__, _allowed_symbols)
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 0559fb415e..7a3fb16825 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
@@ -146,3 +148,81 @@ def predict_signature_def(inputs, outputs):
signature_constants.PREDICT_METHOD_NAME)
return signature_def
+
+
+def _get_shapes_from_tensor_info_dict(tensor_info_dict):
+ """Returns a map of keys to TensorShape objects.
+
+ Args:
+ tensor_info_dict: map with TensorInfo proto as values.
+
+ Returns:
+ Map with corresponding TensorShape objects as values.
+ """
+ return {
+ key: tensor_shape.TensorShape(tensor_info.tensor_shape)
+ for key, tensor_info in tensor_info_dict.items()
+ }
+
+
+def _get_types_from_tensor_info_dict(tensor_info_dict):
+ """Returns a map of keys to DType objects.
+
+ Args:
+ tensor_info_dict: map with TensorInfo proto as values.
+
+ Returns:
+ Map with corresponding DType objects as values.
+ """
+ return {
+ key: dtypes.DType(tensor_info.dtype)
+ for key, tensor_info in tensor_info_dict.items()
+ }
+
+
+def get_signature_def_input_shapes(signature):
+ """Returns map of parameter names to their shapes.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to TensorShape objects.
+ """
+ return _get_shapes_from_tensor_info_dict(signature.inputs)
+
+
+def get_signature_def_input_types(signature):
+ """Returns map of output names to their types.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to DType objects.
+ """
+ return _get_types_from_tensor_info_dict(signature.inputs)
+
+
+def get_signature_def_output_shapes(signature):
+ """Returns map of output names to their shapes.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to TensorShape objects.
+ """
+ return _get_shapes_from_tensor_info_dict(signature.outputs)
+
+
+def get_signature_def_output_types(signature):
+ """Returns map of output names to their types.
+
+ Args:
+ signature: SignatureDef proto.
+
+ Returns:
+ Map from string to DType objects.
+ """
+ return _get_types_from_tensor_info_dict(signature.outputs)
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index 5859496cf3..6627602849 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -24,10 +24,23 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import utils
+def _make_signature(inputs, outputs, name=None):
+ input_info = {
+ input_name: utils.build_tensor_info(tensor)
+ for input_name, tensor in inputs.items()
+ }
+ output_info = {
+ output_name: utils.build_tensor_info(tensor)
+ for output_name, tensor in outputs.items()
+ }
+ return signature_def_utils_impl.build_signature_def(input_info, output_info,
+ name)
+
+
class SignatureDefUtilsTest(test.TestCase):
def testBuildSignatureDef(self):
@@ -41,8 +54,8 @@ class SignatureDefUtilsTest(test.TestCase):
outputs = dict()
outputs["foo-output"] = y_tensor_info
- signature_def = signature_def_utils.build_signature_def(inputs, outputs,
- "foo-method-name")
+ signature_def = signature_def_utils_impl.build_signature_def(
+ inputs, outputs, "foo-method-name")
self.assertEqual("foo-method-name", signature_def.method_name)
# Check inputs in signature def.
@@ -63,8 +76,8 @@ class SignatureDefUtilsTest(test.TestCase):
def testRegressionSignatureDef(self):
input1 = constant_op.constant("a", name="input-1")
output1 = constant_op.constant("b", name="output-1")
- signature_def = signature_def_utils.regression_signature_def(input1,
- output1)
+ signature_def = signature_def_utils_impl.regression_signature_def(
+ input1, output1)
self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
signature_def.method_name)
@@ -89,9 +102,8 @@ class SignatureDefUtilsTest(test.TestCase):
input1 = constant_op.constant("a", name="input-1")
output1 = constant_op.constant("b", name="output-1")
output2 = constant_op.constant("c", name="output-2")
- signature_def = signature_def_utils.classification_signature_def(input1,
- output1,
- output2)
+ signature_def = signature_def_utils_impl.classification_signature_def(
+ input1, output1, output2)
self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
signature_def.method_name)
@@ -122,7 +134,7 @@ class SignatureDefUtilsTest(test.TestCase):
input2 = constant_op.constant("b", name="input-2")
output1 = constant_op.constant("c", name="output-1")
output2 = constant_op.constant("d", name="output-2")
- signature_def = signature_def_utils.predict_signature_def({
+ signature_def = signature_def_utils_impl.predict_signature_def({
"input-1": input1,
"input-2": input2
}, {"output-1": output1,
@@ -153,6 +165,44 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+ def testGetShapeAndTypes(self):
+ inputs = {
+ "input-1": constant_op.constant(["a", "b"]),
+ "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
+ }
+ outputs = {
+ "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
+ "output-2": constant_op.constant([["b"]]),
+ }
+ signature_def = _make_signature(inputs, outputs)
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_input_shapes(signature_def),
+ {"input-1": [2], "input-2": [10, 11]})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_output_shapes(signature_def),
+ {"output-1": [10, 32], "output-2": [1, 1]})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_input_types(signature_def),
+ {"input-1": dtypes.string, "input-2": dtypes.float32})
+ self.assertEqual(
+ signature_def_utils_impl.get_signature_def_output_types(signature_def),
+ {"output-1": dtypes.float32, "output-2": dtypes.string})
+
+ def testGetNonFullySpecifiedShapes(self):
+ outputs = {
+ "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
+ "output-2": array_ops.sparse_placeholder(dtypes.float32),
+ }
+ signature_def = _make_signature({}, outputs)
+ shapes = signature_def_utils_impl.get_signature_def_output_shapes(
+ signature_def)
+ self.assertEqual(len(shapes), 2)
+ # Must compare shapes with as_list() since 2 equivalent non-fully defined
+ # shapes are not equal to each other.
+ self.assertEqual(shapes["output-1"].as_list(), [None, 10, None])
+ # Must compare `dims` since its an unknown shape.
+ self.assertEqual(shapes["output-2"].dims, None)
+
if __name__ == "__main__":
test.main()