aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 12:02:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 12:06:36 -0700
commit2fcec016cec1ec70ba715c9b2f4c759c71eaafca (patch)
treeb73548e8df4b99b68992e2a9f4563e8ec1041354 /tensorflow/contrib/saved_model
parent22e855159462b502dc3af138d254214bd02cf68b (diff)
Add IsValidSignature method to signature_def_utils
PiperOrigin-RevId: 211498364
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/BUILD2
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc81
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h3
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc123
4 files changed, 201 insertions, 8 deletions
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index 3c616c555b..ea4d41d43b 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -30,6 +30,7 @@ cc_library(
hdrs = ["signature_def_utils.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
@@ -42,6 +43,7 @@ tf_cc_test(
srcs = ["signature_def_utils_test.cc"],
deps = [
":signature_def_utils",
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
index a45908d272..e87e497e5f 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description,
*value = &it->second;
return Status::OK();
}
+
+// Looks up the TensorInfo for the given key in the given map and verifies that
+// its datatype matches the given correct datatype.
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
+ const string& key, DataType correct_dtype) {
+ const TensorInfo* tensor_info;
+ const Status& status = FindInProtobufMap("", map, key, &tensor_info);
+ if (!status.ok()) {
+ return false;
+ }
+ if (tensor_info->dtype() != correct_dtype) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidPredictSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kPredictMethodName) {
+ return false;
+ }
+ if (signature_def.inputs().empty()) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidRegressionSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kRegressMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
+ DT_FLOAT)) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidClassificationSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kClassifyMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ for (auto const& output : signature_def.outputs()) {
+ const string& key = output.first;
+ const TensorInfo& tensor_info = output.second;
+ if (key == kClassifyOutputClasses) {
+ if (tensor_info.dtype() != DT_STRING) {
+ return false;
+ }
+ } else if (key == kClassifyOutputScores) {
+ if (tensor_info.dtype() != DT_FLOAT) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace
Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
@@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
return Status::OK();
}
+bool IsValidSignature(const SignatureDef& signature_def) {
+ return IsValidClassificationSignature(signature_def) ||
+ IsValidRegressionSignature(signature_def) ||
+ IsValidPredictSignature(signature_def);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index b732cdd41e..bb24faa989 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def,
Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
const string& tensor_info_key, string* name);
+// Determine whether a SignatureDef can be served by TensorFlow Serving.
+bool IsValidSignature(const SignatureDef& signature_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
index a063e95696..c743112ce0 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -22,7 +23,7 @@ limitations under the License.
namespace tensorflow {
-class SignatureDefUtilsTest : public ::testing::Test {
+class FindByKeyTest : public ::testing::Test {
protected:
MetaGraphDef MakeSampleMetaGraphDef() {
MetaGraphDef result;
@@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test {
return result;
}
+ void SetInputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_inputs())[key].set_name(name);
+ }
+
+ void SetOutputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_outputs())[key].set_name(name);
+ }
+
SignatureDef MakeSampleSignatureDef() {
SignatureDef result;
result.set_method_name(kMethodName);
- (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
- (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
- (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
- (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+ SetInputNameForKey(kInput1Key, kInput1Name, &result);
+ SetInputNameForKey(kInput2Key, kInput2Name, &result);
+ SetOutputNameForKey(kOutput1Key, kOutput1Name, &result);
+ SetOutputNameForKey(kOutput2Key, kOutput2Name, &result);
return result;
}
@@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test {
const string kOutput2Name = "output_two";
};
-TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+TEST_F(FindByKeyTest, FindSignatureDefByKey) {
const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
const SignatureDef* signature_def;
// Succeeds for an existing signature.
@@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
.ok());
}
-TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindInputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing input.
@@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
-TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindOutputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing output.
@@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
+class IsValidSignatureTest : public ::testing::Test {
+ protected:
+ void SetInputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_inputs())[key].set_dtype(dtype);
+ }
+
+ void SetOutputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_outputs())[key].set_dtype(dtype);
+ }
+
+ void EraseOutputKey(const string& key) {
+ (*signature_def_.mutable_outputs()).erase(key);
+ }
+
+ void ExpectInvalidSignature() {
+ EXPECT_FALSE(IsValidSignature(signature_def_));
+ }
+
+ void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); }
+
+ SignatureDef signature_def_;
+};
+
+TEST_F(IsValidSignatureTest, IsValidPredictSignature) {
+ signature_def_.set_method_name("not_kPredictMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kPredictMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kPredictInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kPredictOutputs, DT_STRING);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidRegressionSignature) {
+ signature_def_.set_method_name("not_kRegressMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kRegressMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kRegressInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_STRING);
+ // Incorrect data type
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidClassificationSignature) {
+ signature_def_.set_method_name("not_kClassifyMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kClassifyMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kClassifyInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey("invalidKey", DT_FLOAT);
+ // Invalid key
+ ExpectInvalidSignature();
+
+ EraseOutputKey("invalidKey");
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT);
+ // Invalid dtype for classes
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING);
+ // Valid without scores
+ ExpectValidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING);
+ // Invalid dtype for scores
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT);
+ // Valid with both classes and scores
+ ExpectValidSignature();
+}
+
} // namespace tensorflow