aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-19 11:37:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-19 12:45:09 -0700
commit2152312bacdc3c1146fabf3563310b5fd0bf02c8 (patch)
tree3cd941f73ff2d02b74127239c88714ade4fe6e58 /tensorflow/contrib/saved_model
parent721c48c611058aa2d43f7d3b99f5b445741765be (diff)
Add tensorflow/contrib/saved_model/cc/saved_model/signature_def_util.h with
utilities to find SignatureDefs and the tensor endpoints referenced by them, using tensorflow::Status to make it easy to forward errors. Change: 153620106
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/BUILD55
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc77
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h69
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc92
4 files changed, 293 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
new file mode 100644
index 0000000000..f3d98cfbbe
--- /dev/null
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -0,0 +1,55 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Description:
+# SavedModel contrib libraries for C++.
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "signature_def_utils",
+ srcs = ["signature_def_utils.cc"],
+ hdrs = ["signature_def_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "signature_def_utils_test",
+ srcs = ["signature_def_utils_test.cc"],
+ deps = [
+ ":signature_def_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["*"]),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
new file mode 100644
index 0000000000..a45908d272
--- /dev/null
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+namespace {
+template <class T>
+Status FindInProtobufMap(StringPiece description,
+ const protobuf::Map<string, T>& map, const string& key,
+ const T** value) {
+ const auto it = map.find(key);
+ if (it == map.end()) {
+ return errors::NotFound("Could not find ", description, " for key: ", key);
+ }
+ *value = &it->second;
+ return Status::OK();
+}
+} // namespace
+
+Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
+ const string& signature_def_key,
+ const SignatureDef** signature_def) {
+ return FindInProtobufMap("SignatureDef", meta_graph_def.signature_def(),
+ signature_def_key, signature_def);
+}
+
+Status FindInputTensorInfoByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key,
+ const TensorInfo** tensor_info) {
+ return FindInProtobufMap("input TensorInfo", signature_def.inputs(),
+ tensor_info_key, tensor_info);
+}
+
+Status FindOutputTensorInfoByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key,
+ const TensorInfo** tensor_info) {
+ return FindInProtobufMap("output TensorInfo", signature_def.outputs(),
+ tensor_info_key, tensor_info);
+}
+
+Status FindInputTensorNameByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key, string* name) {
+ const TensorInfo* tensor_info;
+ TF_RETURN_IF_ERROR(
+ FindInputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info));
+ *name = tensor_info->name();
+ return Status::OK();
+}
+
+Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key, string* name) {
+ const TensorInfo* tensor_info;
+ TF_RETURN_IF_ERROR(
+ FindOutputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info));
+ *name = tensor_info->name();
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
new file mode 100644
index 0000000000..c0df224bc8
--- /dev/null
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helpers for working with the SignatureDefs of TensorFlow SavedModels.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+
+// Finds the entry in meta_graph_def.signature_def with the given key, or
+// returns NotFound and leaves *signature_def unchanged. NOTE: The output
+// SignatureDef* points into meta_graph_def and may be invalidated by changes
+// to that protocol buffer, as usual.
+Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
+ const string& signature_def_key,
+ const SignatureDef** signature_def);
+
+// Finds the entry in signature_def.inputs with the given key, or returns
+// NotFound and leaves *tensor_info unchanged. NOTE: The output TensorInfo*
+// points into signature_def and may be invalidated by changes to that protocol
+// buffer, as usual.
+Status FindInputTensorInfoByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key,
+ const TensorInfo** tensor_info);
+
+// Finds the entry in signature_def.outputs with the given key, or returns
+// NotFound and leaves *tensor_info unchanged. NOTE: The output TensorInfo*
+// points into signature_def and may be invalidated by changes to that protocol
+// buffer, as usual.
+Status FindOutputTensorInfoByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key,
+ const TensorInfo** tensor_info);
+
+// Finds the entry in signature_def.inputs with the given key and copies out
+// the name of this Tensor in the graph, or returns NotFound and leaves *name
+// unchanged.
+Status FindInputTensorNameByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key, string* name);
+
+// Finds the entry in signature_def.outputs with the given key and copies out
+// the name of this Tensor in the graph, or returns NotFound and leaves *name
+// unchanged.
+Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
+ const string& tensor_info_key, string* name);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
new file mode 100644
index 0000000000..a063e95696
--- /dev/null
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class SignatureDefUtilsTest : public ::testing::Test {
+ protected:
+ MetaGraphDef MakeSampleMetaGraphDef() {
+ MetaGraphDef result;
+ (*result.mutable_signature_def())["blah"].set_method_name("foo");
+ (*result.mutable_signature_def())[kSignatureKey] = MakeSampleSignatureDef();
+ (*result.mutable_signature_def())["gnarl"].set_method_name("blah");
+ return result;
+ }
+
+ SignatureDef MakeSampleSignatureDef() {
+ SignatureDef result;
+ result.set_method_name(kMethodName);
+ (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
+ (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
+ (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
+ (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+ return result;
+ }
+
+ const string kSignatureKey = "my_signature";
+ const string kMethodName = "my_method";
+ const string kInput1Key = "input_one_key";
+ const string kInput1Name = "input_one";
+ const string kInput2Key = "input_two_key";
+ const string kInput2Name = "input_two";
+ const string kOutput1Key = "output_one_key";
+ const string kOutput1Name = "output_one";
+ const string kOutput2Key = "output_two_key";
+ const string kOutput2Name = "output_two";
+};
+
+TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+ const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
+ const SignatureDef* signature_def;
+ // Succeeds for an existing signature.
+ TF_ASSERT_OK(
+ FindSignatureDefByKey(meta_graph_def, kSignatureKey, &signature_def));
+ EXPECT_EQ(kMethodName, signature_def->method_name());
+ // Fails for a missing signature.
+ EXPECT_FALSE(
+ FindSignatureDefByKey(meta_graph_def, "nonexistent", &signature_def)
+ .ok());
+}
+
+TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+ const SignatureDef signature_def = MakeSampleSignatureDef();
+ string name;
+ // Succeeds for an existing input.
+ TF_ASSERT_OK(FindInputTensorNameByKey(signature_def, kInput2Key, &name));
+ EXPECT_EQ(kInput2Name, name);
+ // Fails for a missing input.
+ EXPECT_FALSE(
+ FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
+}
+
+TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+ const SignatureDef signature_def = MakeSampleSignatureDef();
+ string name;
+ // Succeeds for an existing output.
+ TF_ASSERT_OK(FindOutputTensorNameByKey(signature_def, kOutput2Key, &name));
+ EXPECT_EQ(kOutput2Name, name);
+ // Fails for a missing output.
+ EXPECT_FALSE(
+ FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
+}
+
+} // namespace tensorflow