diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-04 12:02:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 12:06:36 -0700 |
commit | 2fcec016cec1ec70ba715c9b2f4c759c71eaafca (patch) | |
tree | b73548e8df4b99b68992e2a9f4563e8ec1041354 /tensorflow/contrib/saved_model | |
parent | 22e855159462b502dc3af138d254214bd02cf68b (diff) |
Add IsValidSignature method to signature_def_utils
PiperOrigin-RevId: 211498364
Diffstat (limited to 'tensorflow/contrib/saved_model')
4 files changed, 201 insertions, 8 deletions
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD index 3c616c555b..ea4d41d43b 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD @@ -30,6 +30,7 @@ cc_library( hdrs = ["signature_def_utils.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -42,6 +43,7 @@ tf_cc_test( srcs = ["signature_def_utils_test.cc"], deps = [ ":signature_def_utils", + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc index a45908d272..e87e497e5f 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" @@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description, *value = &it->second; return Status::OK(); } + +// Looks up the TensorInfo for the given key in the given map and verifies that +// its datatype matches the given correct datatype. +bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map, + const string& key, DataType correct_dtype) { + const TensorInfo* tensor_info; + const Status& status = FindInProtobufMap("", map, key, &tensor_info); + if (!status.ok()) { + return false; + } + if (tensor_info->dtype() != correct_dtype) { + return false; + } + return true; +} + +bool IsValidPredictSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kPredictMethodName) { + return false; + } + if (signature_def.inputs().empty()) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + return true; +} + +bool IsValidRegressionSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kRegressMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs, + DT_STRING)) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs, + DT_FLOAT)) { + return false; + } + return true; +} + +bool IsValidClassificationSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kClassifyMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs, + DT_STRING)) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + for (auto const& output : signature_def.outputs()) { + const string& key = output.first; + const TensorInfo& tensor_info = output.second; + if (key == kClassifyOutputClasses) { + if (tensor_info.dtype() != DT_STRING) { + return false; + } + } else if (key == kClassifyOutputScores) { + if (tensor_info.dtype() != DT_FLOAT) { + return false; + } + } else { + return false; + } + } + return true; +} + } // namespace Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def, @@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def, return Status::OK(); } +bool IsValidSignature(const SignatureDef& signature_def) { + return IsValidClassificationSignature(signature_def) || + IsValidRegressionSignature(signature_def) || + IsValidPredictSignature(signature_def); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h index b732cdd41e..bb24faa989 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h @@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def, Status FindOutputTensorNameByKey(const SignatureDef& signature_def, const string& tensor_info_key, string* name); +// Determine whether a SignatureDef can be served by TensorFlow Serving. +bool IsValidSignature(const SignatureDef& signature_def); + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc index a063e95696..c743112ce0 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -22,7 +23,7 @@ limitations under the License. namespace tensorflow { -class SignatureDefUtilsTest : public ::testing::Test { +class FindByKeyTest : public ::testing::Test { protected: MetaGraphDef MakeSampleMetaGraphDef() { MetaGraphDef result; @@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test { return result; } + void SetInputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_inputs())[key].set_name(name); + } + + void SetOutputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_outputs())[key].set_name(name); + } + SignatureDef MakeSampleSignatureDef() { SignatureDef result; result.set_method_name(kMethodName); - (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name); - (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name); - (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name); - (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name); + SetInputNameForKey(kInput1Key, kInput1Name, &result); + SetInputNameForKey(kInput2Key, kInput2Name, &result); + SetOutputNameForKey(kOutput1Key, kOutput1Name, &result); + SetOutputNameForKey(kOutput2Key, kOutput2Name, &result); return result; } @@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test { const string kOutput2Name = "output_two"; }; -TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { +TEST_F(FindByKeyTest, FindSignatureDefByKey) { const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef(); const SignatureDef* signature_def; // Succeeds for an existing signature. @@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { .ok()); } -TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { +TEST_F(FindByKeyTest, FindInputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing input. @@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } -TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { +TEST_F(FindByKeyTest, FindOutputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing output. @@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } +class IsValidSignatureTest : public ::testing::Test { + protected: + void SetInputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_inputs())[key].set_dtype(dtype); + } + + void SetOutputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_outputs())[key].set_dtype(dtype); + } + + void EraseOutputKey(const string& key) { + (*signature_def_.mutable_outputs()).erase(key); + } + + void ExpectInvalidSignature() { + EXPECT_FALSE(IsValidSignature(signature_def_)); + } + + void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); } + + SignatureDef signature_def_; +}; + +TEST_F(IsValidSignatureTest, IsValidPredictSignature) { + signature_def_.set_method_name("not_kPredictMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kPredictMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kPredictInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kPredictOutputs, DT_STRING); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidRegressionSignature) { + signature_def_.set_method_name("not_kRegressMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kRegressMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kRegressInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kRegressOutputs, DT_STRING); + // Incorrect data type + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidClassificationSignature) { + signature_def_.set_method_name("not_kClassifyMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kClassifyMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kClassifyInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey("invalidKey", DT_FLOAT); + // Invalid key + ExpectInvalidSignature(); + + EraseOutputKey("invalidKey"); + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT); + // Invalid dtype for classes + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING); + // Valid without scores + ExpectValidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING); + // Invalid dtype for scores + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT); + // Valid with both classes and scores + ExpectValidSignature(); +} + } // namespace tensorflow |