diff options
author | David Soergel <soergel@google.com> | 2017-09-28 11:55:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-28 12:01:58 -0700 |
commit | 996b0342879af43de1bf4071190b90ff7309428a (patch) | |
tree | 9cfe19f90e59c22140cfa419c04db00b520871b3 /tensorflow/python/saved_model | |
parent | 0254d0d31337724db911c89609336afd60e8192d (diff) |
Add more validation of serving signatures, both at creation and post hoc.
PiperOrigin-RevId: 170376578
Diffstat (limited to 'tensorflow/python/saved_model')
3 files changed, 257 insertions, 12 deletions
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py index a7c648ce2f..ea0f52f17e 100644 --- a/tensorflow/python/saved_model/signature_def_utils.py +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -23,6 +23,7 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def # pylint: enable=unused-import diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index 7a3fb16825..564befeb0b 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils @@ -64,15 +67,22 @@ def regression_signature_def(examples, predictions): ValueError: If examples is `None`. """ if examples is None: - raise ValueError('examples cannot be None for regression.') + raise ValueError('Regression examples cannot be None.') + if not isinstance(examples, ops.Tensor): + raise ValueError('Regression examples must be a string Tensor.') if predictions is None: - raise ValueError('predictions cannot be None for regression.') + raise ValueError('Regression predictions cannot be None.') input_tensor_info = utils.build_tensor_info(examples) + if input_tensor_info.dtype != types_pb2.DT_STRING: + raise ValueError('Regression examples must be a string Tensor.') signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info} output_tensor_info = utils.build_tensor_info(predictions) + if output_tensor_info.dtype != types_pb2.DT_FLOAT: + raise ValueError('Regression output must be a float Tensor.') signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info} + signature_def = build_signature_def( signature_inputs, signature_outputs, signature_constants.REGRESS_METHOD_NAME) @@ -95,21 +105,28 @@ def classification_signature_def(examples, classes, scores): ValueError: If examples is `None`. """ if examples is None: - raise ValueError('examples cannot be None for classification.') + raise ValueError('Classification examples cannot be None.') + if not isinstance(examples, ops.Tensor): + raise ValueError('Classification examples must be a string Tensor.') if classes is None and scores is None: - raise ValueError('classes and scores cannot both be None for ' - 'classification.') + raise ValueError('Classification classes and scores cannot both be None.') input_tensor_info = utils.build_tensor_info(examples) + if input_tensor_info.dtype != types_pb2.DT_STRING: + raise ValueError('Classification examples must be a string Tensor.') signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info} signature_outputs = {} if classes is not None: classes_tensor_info = utils.build_tensor_info(classes) + if classes_tensor_info.dtype != types_pb2.DT_STRING: + raise ValueError('Classification classes must be a string Tensor.') signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = ( classes_tensor_info) if scores is not None: scores_tensor_info = utils.build_tensor_info(scores) + if scores_tensor_info.dtype != types_pb2.DT_FLOAT: + raise ValueError('Classification scores must be a float Tensor.') signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = ( scores_tensor_info) @@ -134,9 +151,9 @@ def predict_signature_def(inputs, outputs): ValueError: If inputs or outputs is `None`. """ if inputs is None or not inputs: - raise ValueError('inputs cannot be None or empty for prediction.') - if outputs is None: - raise ValueError('outputs cannot be None or empty for prediction.') + raise ValueError('Prediction inputs cannot be None or empty.') + if outputs is None or not outputs: + raise ValueError('Prediction outputs cannot be None or empty.') signature_inputs = {key: utils.build_tensor_info(tensor) for key, tensor in inputs.items()} @@ -150,6 +167,81 @@ def predict_signature_def(inputs, outputs): return signature_def +def is_valid_signature(signature_def): + """Determine whether a SignatureDef can be served by TensorFlow Serving.""" + if signature_def is None: + return False + return (_is_valid_classification_signature(signature_def) or + _is_valid_regression_signature(signature_def) or + _is_valid_predict_signature(signature_def)) + + +def _is_valid_predict_signature(signature_def): + """Determine whether the argument is a servable 'predict' SignatureDef.""" + if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME: + return False + if not signature_def.inputs.keys(): + return False + if not signature_def.outputs.keys(): + return False + return True + + +def _is_valid_regression_signature(signature_def): + """Determine whether the argument is a servable 'regress' SignatureDef.""" + if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME: + return False + + if (set(signature_def.inputs.keys()) + != set([signature_constants.REGRESS_INPUTS])): + return False + if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype != + types_pb2.DT_STRING): + return False + + if (set(signature_def.outputs.keys()) + != set([signature_constants.REGRESS_OUTPUTS])): + return False + if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype != + types_pb2.DT_FLOAT): + return False + + return True + + +def _is_valid_classification_signature(signature_def): + """Determine whether the argument is a servable 'classify' SignatureDef.""" + if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME: + return False + + if (set(signature_def.inputs.keys()) + != set([signature_constants.CLASSIFY_INPUTS])): + return False + if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype != + types_pb2.DT_STRING): + return False + + allowed_outputs = set([signature_constants.CLASSIFY_OUTPUT_CLASSES, + signature_constants.CLASSIFY_OUTPUT_SCORES]) + + if not signature_def.outputs.keys(): + return False + if set(signature_def.outputs.keys()) - allowed_outputs: + return False + if (signature_constants.CLASSIFY_OUTPUT_CLASSES in signature_def.outputs + and + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype + != types_pb2.DT_STRING): + return False + if (signature_constants.CLASSIFY_OUTPUT_SCORES in signature_def.outputs + and + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype != + types_pb2.DT_FLOAT): + return False + + return True + + def _get_shapes_from_tensor_info_dict(tensor_info_dict): """Returns a map of keys to TensorShape objects. diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index 6627602849..b2bd14db8c 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -28,6 +29,20 @@ from tensorflow.python.saved_model import signature_def_utils_impl from tensorflow.python.saved_model import utils +# We'll reuse the same tensor_infos in multiple contexts just for the tests. +# The validator doesn't check shapes so we just omit them. +_STRING = meta_graph_pb2.TensorInfo( + name="foobar", + dtype=dtypes.string.as_datatype_enum +) + + +_FLOAT = meta_graph_pb2.TensorInfo( + name="foobar", + dtype=dtypes.float32.as_datatype_enum +) + + def _make_signature(inputs, outputs, name=None): input_info = { input_name: utils.build_tensor_info(tensor) @@ -75,7 +90,7 @@ class SignatureDefUtilsTest(test.TestCase): def testRegressionSignatureDef(self): input1 = constant_op.constant("a", name="input-1") - output1 = constant_op.constant("b", name="output-1") + output1 = constant_op.constant(2.2, name="output-1") signature_def = signature_def_utils_impl.regression_signature_def( input1, output1) @@ -95,13 +110,13 @@ class SignatureDefUtilsTest(test.TestCase): y_tensor_info_actual = ( signature_def.outputs[signature_constants.REGRESS_OUTPUTS]) self.assertEqual("output-1:0", y_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype) + self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) def testClassificationSignatureDef(self): input1 = constant_op.constant("a", name="input-1") output1 = constant_op.constant("b", name="output-1") - output2 = constant_op.constant("c", name="output-2") + output2 = constant_op.constant(3.3, name="output-2") signature_def = signature_def_utils_impl.classification_signature_def( input1, output1, output2) @@ -126,7 +141,7 @@ class SignatureDefUtilsTest(test.TestCase): scores_tensor_info_actual = ( signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES]) self.assertEqual("output-2:0", scores_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype) + self.assertEqual(types_pb2.DT_FLOAT, scores_tensor_info_actual.dtype) self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim)) def testPredictionSignatureDef(self): @@ -203,6 +218,143 @@ class SignatureDefUtilsTest(test.TestCase): # Must compare `dims` since its an unknown shape. self.assertEqual(shapes["output-2"].dims, None) + def _assertValidSignature(self, inputs, outputs, method_name): + signature_def = signature_def_utils_impl.build_signature_def( + inputs, outputs, method_name) + self.assertTrue( + signature_def_utils_impl.is_valid_signature(signature_def)) + + def _assertInvalidSignature(self, inputs, outputs, method_name): + signature_def = signature_def_utils_impl.build_signature_def( + inputs, outputs, method_name) + self.assertFalse( + signature_def_utils_impl.is_valid_signature(signature_def)) + + def testValidSignaturesAreAccepted(self): + self._assertValidSignature( + {"inputs": _STRING}, + {"classes": _STRING, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertValidSignature( + {"inputs": _STRING}, + {"classes": _STRING}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertValidSignature( + {"inputs": _STRING}, + {"scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertValidSignature( + {"inputs": _STRING}, + {"outputs": _FLOAT}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertValidSignature( + {"foo": _STRING, "bar": _FLOAT}, + {"baz": _STRING, "qux": _FLOAT}, + signature_constants.PREDICT_METHOD_NAME) + + def testInvalidMethodNameSignatureIsRejected(self): + # WRONG METHOD + self._assertInvalidSignature( + {"inputs": _STRING}, + {"classes": _STRING, "scores": _FLOAT}, + "WRONG method name") + + def testInvalidClassificationSignaturesAreRejected(self): + # CLASSIFY: wrong types + self._assertInvalidSignature( + {"inputs": _FLOAT}, + {"classes": _STRING, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"classes": _FLOAT, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"classes": _STRING, "scores": _STRING}, + signature_constants.CLASSIFY_METHOD_NAME) + + # CLASSIFY: wrong keys + self._assertInvalidSignature( + {}, + {"classes": _STRING, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs_WRONG": _STRING}, + {"classes": _STRING, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"classes_WRONG": _STRING, "scores": _FLOAT}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {}, + signature_constants.CLASSIFY_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING}, + signature_constants.CLASSIFY_METHOD_NAME) + + def testInvalidRegressionSignaturesAreRejected(self): + # REGRESS: wrong types + self._assertInvalidSignature( + {"inputs": _FLOAT}, + {"outputs": _FLOAT}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"outputs": _STRING}, + signature_constants.REGRESS_METHOD_NAME) + + # REGRESS: wrong keys + self._assertInvalidSignature( + {}, + {"outputs": _FLOAT}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs_WRONG": _STRING}, + {"outputs": _FLOAT}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"outputs_WRONG": _FLOAT}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {}, + signature_constants.REGRESS_METHOD_NAME) + + self._assertInvalidSignature( + {"inputs": _STRING}, + {"outputs": _FLOAT, "extra_WRONG": _STRING}, + signature_constants.REGRESS_METHOD_NAME) + + def testInvalidPredictSignaturesAreRejected(self): + # PREDICT: wrong keys + self._assertInvalidSignature( + {}, + {"baz": _STRING, "qux": _FLOAT}, + signature_constants.PREDICT_METHOD_NAME) + + self._assertInvalidSignature( + {"foo": _STRING, "bar": _FLOAT}, + {}, + signature_constants.PREDICT_METHOD_NAME) if __name__ == "__main__": test.main() |