diff options
author | 2017-04-19 11:37:39 -0800 | |
---|---|---|
committer | 2017-04-19 12:45:09 -0700 | |
commit | 2152312bacdc3c1146fabf3563310b5fd0bf02c8 (patch) | |
tree | 3cd941f73ff2d02b74127239c88714ade4fe6e58 /tensorflow/contrib/saved_model | |
parent | 721c48c611058aa2d43f7d3b99f5b445741765be (diff) |
Add tensorflow/contrib/saved_model/cc/saved_model/signature_def_util.h with
utilities to find SignatureDefs and the tensor endpoints referenced by them,
using tensorflow::Status to make it easy to forward errors.
Change: 153620106
Diffstat (limited to 'tensorflow/contrib/saved_model')
4 files changed, 293 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD new file mode 100644 index 0000000000..f3d98cfbbe --- /dev/null +++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD @@ -0,0 +1,55 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Description: +# SavedModel contrib libraries for C++. + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "signature_def_utils", + srcs = ["signature_def_utils.cc"], + hdrs = ["signature_def_utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "signature_def_utils_test", + srcs = ["signature_def_utils_test.cc"], + deps = [ + ":signature_def_utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["*"]), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc new file mode 100644 index 0000000000..a45908d272 --- /dev/null +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +namespace { +template <class T> +Status FindInProtobufMap(StringPiece description, + const protobuf::Map<string, T>& map, const string& key, + const T** value) { + const auto it = map.find(key); + if (it == map.end()) { + return errors::NotFound("Could not find ", description, " for key: ", key); + } + *value = &it->second; + return Status::OK(); +} +} // namespace + +Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def, + const string& signature_def_key, + const SignatureDef** signature_def) { + return FindInProtobufMap("SignatureDef", meta_graph_def.signature_def(), + signature_def_key, signature_def); +} + +Status FindInputTensorInfoByKey(const SignatureDef& signature_def, + const string& tensor_info_key, + const TensorInfo** tensor_info) { + return FindInProtobufMap("input TensorInfo", signature_def.inputs(), + tensor_info_key, tensor_info); +} + +Status FindOutputTensorInfoByKey(const SignatureDef& signature_def, + const string& tensor_info_key, + const TensorInfo** tensor_info) { + return FindInProtobufMap("output TensorInfo", signature_def.outputs(), + tensor_info_key, tensor_info); +} + +Status FindInputTensorNameByKey(const SignatureDef& signature_def, + const string& tensor_info_key, string* name) { + const TensorInfo* tensor_info; + TF_RETURN_IF_ERROR( + FindInputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info)); + *name = tensor_info->name(); + return Status::OK(); +} + +Status FindOutputTensorNameByKey(const SignatureDef& signature_def, + const string& tensor_info_key, string* name) { + const TensorInfo* tensor_info; + TF_RETURN_IF_ERROR( + FindOutputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info)); + *name = tensor_info->name(); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h new file mode 100644 index 0000000000..c0df224bc8 --- /dev/null +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for working with the SignatureDefs of TensorFlow SavedModels. + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +// Finds the entry in meta_graph_def.signature_def with the given key, or +// returns NotFound and leaves *signature_def unchanged. NOTE: The output +// SignatureDef* points into meta_graph_def and may be invalidated by changes +// to that protocol buffer, as usual. +Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def, + const string& signature_def_key, + const SignatureDef** signature_def); + +// Finds the entry in signature_def.inputs with the given key, or returns +// NotFound and leaves *tensor_info unchanged. NOTE: The output TensorInfo* +// points into signature_def and may be invalidated by changes to that protocol +// buffer, as usual. +Status FindInputTensorInfoByKey(const SignatureDef& signature_def, + const string& tensor_info_key, + const TensorInfo** tensor_info); + +// Finds the entry in signature_def.outputs with the given key, or returns +// NotFound and leaves *tensor_info unchanged. NOTE: The output TensorInfo* +// points into signature_def and may be invalidated by changes to that protocol +// buffer, as usual. +Status FindOutputTensorInfoByKey(const SignatureDef& signature_def, + const string& tensor_info_key, + const TensorInfo** tensor_info); + +// Finds the entry in signature_def.inputs with the given key and copies out +// the name of this Tensor in the graph, or returns NotFound and leaves *name +// unchanged. +Status FindInputTensorNameByKey(const SignatureDef& signature_def, + const string& tensor_info_key, string* name); + +// Finds the entry in signature_def.outputs with the given key and copies out +// the name of this Tensor in the graph, or returns NotFound and leaves *name +// unchanged. +Status FindOutputTensorNameByKey(const SignatureDef& signature_def, + const string& tensor_info_key, string* name); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc new file mode 100644 index 0000000000..a063e95696 --- /dev/null +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class SignatureDefUtilsTest : public ::testing::Test { + protected: + MetaGraphDef MakeSampleMetaGraphDef() { + MetaGraphDef result; + (*result.mutable_signature_def())["blah"].set_method_name("foo"); + (*result.mutable_signature_def())[kSignatureKey] = MakeSampleSignatureDef(); + (*result.mutable_signature_def())["gnarl"].set_method_name("blah"); + return result; + } + + SignatureDef MakeSampleSignatureDef() { + SignatureDef result; + result.set_method_name(kMethodName); + (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name); + (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name); + (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name); + (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name); + return result; + } + + const string kSignatureKey = "my_signature"; + const string kMethodName = "my_method"; + const string kInput1Key = "input_one_key"; + const string kInput1Name = "input_one"; + const string kInput2Key = "input_two_key"; + const string kInput2Name = "input_two"; + const string kOutput1Key = "output_one_key"; + const string kOutput1Name = "output_one"; + const string kOutput2Key = "output_two_key"; + const string kOutput2Name = "output_two"; +}; + +TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { + const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef(); + const SignatureDef* signature_def; + // Succeeds for an existing signature. + TF_ASSERT_OK( + FindSignatureDefByKey(meta_graph_def, kSignatureKey, &signature_def)); + EXPECT_EQ(kMethodName, signature_def->method_name()); + // Fails for a missing signature. + EXPECT_FALSE( + FindSignatureDefByKey(meta_graph_def, "nonexistent", &signature_def) + .ok()); +} + +TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { + const SignatureDef signature_def = MakeSampleSignatureDef(); + string name; + // Succeeds for an existing input. + TF_ASSERT_OK(FindInputTensorNameByKey(signature_def, kInput2Key, &name)); + EXPECT_EQ(kInput2Name, name); + // Fails for a missing input. + EXPECT_FALSE( + FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok()); +} + +TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { + const SignatureDef signature_def = MakeSampleSignatureDef(); + string name; + // Succeeds for an existing output. + TF_ASSERT_OK(FindOutputTensorNameByKey(signature_def, kOutput2Key, &name)); + EXPECT_EQ(kOutput2Name, name); + // Fails for a missing output. + EXPECT_FALSE( + FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok()); +} + +} // namespace tensorflow |