aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-09-28 11:55:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 12:01:58 -0700
commit996b0342879af43de1bf4071190b90ff7309428a (patch)
tree9cfe19f90e59c22140cfa419c04db00b520871b3 /tensorflow/python/saved_model
parent0254d0d31337724db911c89609336afd60e8192d (diff)
Add more validation of serving signatures, both at creation and post hoc.
PiperOrigin-RevId: 170376578
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py1
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py108
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py160
3 files changed, 257 insertions, 12 deletions
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index a7c648ce2f..ea0f52f17e 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -23,6 +23,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
# pylint: enable=unused-import
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 7a3fb16825..564befeb0b 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
@@ -64,15 +67,22 @@ def regression_signature_def(examples, predictions):
ValueError: If examples is `None`.
"""
if examples is None:
- raise ValueError('examples cannot be None for regression.')
+ raise ValueError('Regression examples cannot be None.')
+ if not isinstance(examples, ops.Tensor):
+ raise ValueError('Regression examples must be a string Tensor.')
if predictions is None:
- raise ValueError('predictions cannot be None for regression.')
+ raise ValueError('Regression predictions cannot be None.')
input_tensor_info = utils.build_tensor_info(examples)
+ if input_tensor_info.dtype != types_pb2.DT_STRING:
+ raise ValueError('Regression examples must be a string Tensor.')
signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}
output_tensor_info = utils.build_tensor_info(predictions)
+ if output_tensor_info.dtype != types_pb2.DT_FLOAT:
+ raise ValueError('Regression output must be a float Tensor.')
signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}
+
signature_def = build_signature_def(
signature_inputs, signature_outputs,
signature_constants.REGRESS_METHOD_NAME)
@@ -95,21 +105,28 @@ def classification_signature_def(examples, classes, scores):
ValueError: If examples is `None`.
"""
if examples is None:
- raise ValueError('examples cannot be None for classification.')
+ raise ValueError('Classification examples cannot be None.')
+ if not isinstance(examples, ops.Tensor):
+ raise ValueError('Classification examples must be a string Tensor.')
if classes is None and scores is None:
- raise ValueError('classes and scores cannot both be None for '
- 'classification.')
+ raise ValueError('Classification classes and scores cannot both be None.')
input_tensor_info = utils.build_tensor_info(examples)
+ if input_tensor_info.dtype != types_pb2.DT_STRING:
+ raise ValueError('Classification examples must be a string Tensor.')
signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}
signature_outputs = {}
if classes is not None:
classes_tensor_info = utils.build_tensor_info(classes)
+ if classes_tensor_info.dtype != types_pb2.DT_STRING:
+ raise ValueError('Classification classes must be a string Tensor.')
signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
classes_tensor_info)
if scores is not None:
scores_tensor_info = utils.build_tensor_info(scores)
+ if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
+ raise ValueError('Classification scores must be a float Tensor.')
signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
scores_tensor_info)
@@ -134,9 +151,9 @@ def predict_signature_def(inputs, outputs):
ValueError: If inputs or outputs is `None`.
"""
if inputs is None or not inputs:
- raise ValueError('inputs cannot be None or empty for prediction.')
- if outputs is None:
- raise ValueError('outputs cannot be None or empty for prediction.')
+ raise ValueError('Prediction inputs cannot be None or empty.')
+ if outputs is None or not outputs:
+ raise ValueError('Prediction outputs cannot be None or empty.')
signature_inputs = {key: utils.build_tensor_info(tensor)
for key, tensor in inputs.items()}
@@ -150,6 +167,81 @@ def predict_signature_def(inputs, outputs):
return signature_def
+def is_valid_signature(signature_def):
+ """Determine whether a SignatureDef can be served by TensorFlow Serving."""
+ if signature_def is None:
+ return False
+ return (_is_valid_classification_signature(signature_def) or
+ _is_valid_regression_signature(signature_def) or
+ _is_valid_predict_signature(signature_def))
+
+
+def _is_valid_predict_signature(signature_def):
+ """Determine whether the argument is a servable 'predict' SignatureDef."""
+ if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
+ return False
+ if not signature_def.inputs.keys():
+ return False
+ if not signature_def.outputs.keys():
+ return False
+ return True
+
+
+def _is_valid_regression_signature(signature_def):
+ """Determine whether the argument is a servable 'regress' SignatureDef."""
+ if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
+ return False
+
+ if (set(signature_def.inputs.keys())
+ != set([signature_constants.REGRESS_INPUTS])):
+ return False
+ if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype !=
+ types_pb2.DT_STRING):
+ return False
+
+ if (set(signature_def.outputs.keys())
+ != set([signature_constants.REGRESS_OUTPUTS])):
+ return False
+ if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype !=
+ types_pb2.DT_FLOAT):
+ return False
+
+ return True
+
+
+def _is_valid_classification_signature(signature_def):
+ """Determine whether the argument is a servable 'classify' SignatureDef."""
+ if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
+ return False
+
+ if (set(signature_def.inputs.keys())
+ != set([signature_constants.CLASSIFY_INPUTS])):
+ return False
+ if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype !=
+ types_pb2.DT_STRING):
+ return False
+
+ allowed_outputs = set([signature_constants.CLASSIFY_OUTPUT_CLASSES,
+ signature_constants.CLASSIFY_OUTPUT_SCORES])
+
+ if not signature_def.outputs.keys():
+ return False
+ if set(signature_def.outputs.keys()) - allowed_outputs:
+ return False
+ if (signature_constants.CLASSIFY_OUTPUT_CLASSES in signature_def.outputs
+ and
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype
+ != types_pb2.DT_STRING):
+ return False
+ if (signature_constants.CLASSIFY_OUTPUT_SCORES in signature_def.outputs
+ and
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype !=
+ types_pb2.DT_FLOAT):
+ return False
+
+ return True
+
+
def _get_shapes_from_tensor_info_dict(tensor_info_dict):
"""Returns a map of keys to TensorShape objects.
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index 6627602849..b2bd14db8c 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -28,6 +29,20 @@ from tensorflow.python.saved_model import signature_def_utils_impl
from tensorflow.python.saved_model import utils
+# We'll reuse the same tensor_infos in multiple contexts just for the tests.
+# The validator doesn't check shapes so we just omit them.
+_STRING = meta_graph_pb2.TensorInfo(
+ name="foobar",
+ dtype=dtypes.string.as_datatype_enum
+)
+
+
+_FLOAT = meta_graph_pb2.TensorInfo(
+ name="foobar",
+ dtype=dtypes.float32.as_datatype_enum
+)
+
+
def _make_signature(inputs, outputs, name=None):
input_info = {
input_name: utils.build_tensor_info(tensor)
@@ -75,7 +90,7 @@ class SignatureDefUtilsTest(test.TestCase):
def testRegressionSignatureDef(self):
input1 = constant_op.constant("a", name="input-1")
- output1 = constant_op.constant("b", name="output-1")
+ output1 = constant_op.constant(2.2, name="output-1")
signature_def = signature_def_utils_impl.regression_signature_def(
input1, output1)
@@ -95,13 +110,13 @@ class SignatureDefUtilsTest(test.TestCase):
y_tensor_info_actual = (
signature_def.outputs[signature_constants.REGRESS_OUTPUTS])
self.assertEqual("output-1:0", y_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype)
+ self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
def testClassificationSignatureDef(self):
input1 = constant_op.constant("a", name="input-1")
output1 = constant_op.constant("b", name="output-1")
- output2 = constant_op.constant("c", name="output-2")
+ output2 = constant_op.constant(3.3, name="output-2")
signature_def = signature_def_utils_impl.classification_signature_def(
input1, output1, output2)
@@ -126,7 +141,7 @@ class SignatureDefUtilsTest(test.TestCase):
scores_tensor_info_actual = (
signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
self.assertEqual("output-2:0", scores_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype)
+ self.assertEqual(types_pb2.DT_FLOAT, scores_tensor_info_actual.dtype)
self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim))
def testPredictionSignatureDef(self):
@@ -203,6 +218,143 @@ class SignatureDefUtilsTest(test.TestCase):
# Must compare `dims` since its an unknown shape.
self.assertEqual(shapes["output-2"].dims, None)
+ def _assertValidSignature(self, inputs, outputs, method_name):
+ signature_def = signature_def_utils_impl.build_signature_def(
+ inputs, outputs, method_name)
+ self.assertTrue(
+ signature_def_utils_impl.is_valid_signature(signature_def))
+
+ def _assertInvalidSignature(self, inputs, outputs, method_name):
+ signature_def = signature_def_utils_impl.build_signature_def(
+ inputs, outputs, method_name)
+ self.assertFalse(
+ signature_def_utils_impl.is_valid_signature(signature_def))
+
+ def testValidSignaturesAreAccepted(self):
+ self._assertValidSignature(
+ {"inputs": _STRING},
+ {"classes": _STRING, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertValidSignature(
+ {"inputs": _STRING},
+ {"classes": _STRING},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertValidSignature(
+ {"inputs": _STRING},
+ {"scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertValidSignature(
+ {"inputs": _STRING},
+ {"outputs": _FLOAT},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertValidSignature(
+ {"foo": _STRING, "bar": _FLOAT},
+ {"baz": _STRING, "qux": _FLOAT},
+ signature_constants.PREDICT_METHOD_NAME)
+
+ def testInvalidMethodNameSignatureIsRejected(self):
+ # WRONG METHOD
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"classes": _STRING, "scores": _FLOAT},
+ "WRONG method name")
+
+ def testInvalidClassificationSignaturesAreRejected(self):
+ # CLASSIFY: wrong types
+ self._assertInvalidSignature(
+ {"inputs": _FLOAT},
+ {"classes": _STRING, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"classes": _FLOAT, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"classes": _STRING, "scores": _STRING},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ # CLASSIFY: wrong keys
+ self._assertInvalidSignature(
+ {},
+ {"classes": _STRING, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs_WRONG": _STRING},
+ {"classes": _STRING, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"classes_WRONG": _STRING, "scores": _FLOAT},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ def testInvalidRegressionSignaturesAreRejected(self):
+ # REGRESS: wrong types
+ self._assertInvalidSignature(
+ {"inputs": _FLOAT},
+ {"outputs": _FLOAT},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"outputs": _STRING},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ # REGRESS: wrong keys
+ self._assertInvalidSignature(
+ {},
+ {"outputs": _FLOAT},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs_WRONG": _STRING},
+ {"outputs": _FLOAT},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"outputs_WRONG": _FLOAT},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"inputs": _STRING},
+ {"outputs": _FLOAT, "extra_WRONG": _STRING},
+ signature_constants.REGRESS_METHOD_NAME)
+
+ def testInvalidPredictSignaturesAreRejected(self):
+ # PREDICT: wrong keys
+ self._assertInvalidSignature(
+ {},
+ {"baz": _STRING, "qux": _FLOAT},
+ signature_constants.PREDICT_METHOD_NAME)
+
+ self._assertInvalidSignature(
+ {"foo": _STRING, "bar": _FLOAT},
+ {},
+ signature_constants.PREDICT_METHOD_NAME)
if __name__ == "__main__":
test.main()