diff options
6 files changed, 18 insertions, 45 deletions
diff --git a/tensorflow/cc/saved_model/signature_constants.h b/tensorflow/cc/saved_model/signature_constants.h index 75a2831ab4..5a784874cd 100644 --- a/tensorflow/cc/saved_model/signature_constants.h +++ b/tensorflow/cc/saved_model/signature_constants.h @@ -18,6 +18,11 @@ limitations under the License. namespace tensorflow { +// Key in the signature def map for `default` serving signatures. The default +// signature is used in inference requests where a specific signature was not +// specified. +static constexpr char kDefaultServingSignatureDefKey[] = "serving_default"; + //////////////////////////////////////////////////////////////////////////////// // Classification API constants. diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index ae2c36523a..7750e54569 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -266,9 +266,8 @@ cc_library( srcs = ["bundle_shim.cc"], hdrs = ["bundle_shim.h"], copts = if_ios(["-DGOOGLE_LOGGING"]), - visibility = ["//visibility:private"], + visibility = ["//visibility:public"], deps = [ - ":bundle_shim_constants", ":session_bundle", ":signature", "//tensorflow/cc/saved_model:loader", @@ -295,7 +294,6 @@ cc_test( linkstatic = 1, deps = [ ":bundle_shim", - ":bundle_shim_constants", ":test_util", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:signature_constants", @@ -309,11 +307,6 @@ cc_test( ], ) -cc_library( - name = "bundle_shim_constants", - hdrs = ["bundle_shim_constants.h"], -) - tf_proto_library( name = "manifest_proto", srcs = ["manifest.proto"], diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index 5770719cdc..a63fc77c76 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/cc/saved_model/signature_constants.h" -#include "tensorflow/contrib/session_bundle/bundle_shim_constants.h" #include "tensorflow/contrib/session_bundle/manifest.pb.h" #include "tensorflow/contrib/session_bundle/session_bundle.h" #include "tensorflow/contrib/session_bundle/signature.h" @@ -62,7 +61,7 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures, kRegressInputs, &signature_def); AddOutputToSignatureDef(regression_signature.output().tensor_name(), kRegressOutputs, &signature_def); - (*meta_graph_def->mutable_signature_def())[kDefaultSignatureDefKey] = + (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = signature_def; return Status::OK(); } else if (default_signature.type_case() == @@ -77,7 +76,7 @@ Status ConvertDefaultSignatureToSignatureDef(const Signatures& signatures, kClassifyOutputClasses, &signature_def); AddOutputToSignatureDef(classification_signature.scores().tensor_name(), kClassifyOutputScores, &signature_def); - (*meta_graph_def->mutable_signature_def())[kDefaultSignatureDefKey] = + (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = signature_def; return Status::OK(); } @@ -124,7 +123,7 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures, } // Add the `default` key to the signature def map of the meta graph def and // map it to the constructed signature def. - (*meta_graph_def->mutable_signature_def())[kDefaultSignatureDefKey] = + (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = signature_def; return Status::OK(); } diff --git a/tensorflow/contrib/session_bundle/bundle_shim_constants.h b/tensorflow/contrib/session_bundle/bundle_shim_constants.h deleted file mode 100644 index f49c0bc3e4..0000000000 --- a/tensorflow/contrib/session_bundle/bundle_shim_constants.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_CONSTANTS_H_ - -namespace tensorflow { -namespace serving { - -// Key in the signature def map for `default` signatures. -static constexpr char kDefaultSignatureDefKey[] = "default"; - -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_CONSTANTS_H_ diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index 5d386f182f..cfdd05e608 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" -#include "tensorflow/contrib/session_bundle/bundle_shim_constants.h" #include "tensorflow/contrib/session_bundle/test_util.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" @@ -146,7 +145,7 @@ TEST(BundleShimTest, DefaultSignatureRegression) { ConvertDefaultSignatureToSignatureDef(signatures, &meta_graph_def); EXPECT_EQ(1, meta_graph_def.signature_def_size()); const auto actual_signature_def = - meta_graph_def.signature_def().find(kDefaultSignatureDefKey); + meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); EXPECT_EQ("foo-input", actual_signature_def->second.inputs() .find(kRegressInputs) ->second.name()); @@ -174,7 +173,7 @@ TEST(BundleShimTest, DefaultSignatureClassification) { ConvertDefaultSignatureToSignatureDef(signatures, &meta_graph_def); EXPECT_EQ(1, meta_graph_def.signature_def_size()); const auto actual_signature_def = - meta_graph_def.signature_def().find(kDefaultSignatureDefKey); + meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); EXPECT_EQ("foo-input", actual_signature_def->second.inputs() .find(kClassifyInputs) ->second.name()); @@ -261,7 +260,7 @@ TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) { ConvertNamedSignaturesToSignatureDef(signatures, &meta_graph_def); EXPECT_EQ(1, meta_graph_def.signature_def_size()); const auto actual_signature_def = - meta_graph_def.signature_def().find(kDefaultSignatureDefKey); + meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); EXPECT_EQ( "foo-input", actual_signature_def->second.inputs().find("foo-input")->second.name()); @@ -322,7 +321,7 @@ TEST(BundleShimTest, BasicExportSessionBundle) { const string session_bundle_export_dir = test_util::TestSrcDirPath(kSessionBundlePath); LoadAndValidateSavedModelBundle(session_bundle_export_dir, {"tag"}, - kDefaultSignatureDefKey); + kDefaultServingSignatureDefKey); } // Checks a basic load for half plus two for SavedModelBundle. diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py index 6e0be8ec46..51a57cab05 100644 --- a/tensorflow/python/saved_model/signature_constants.py +++ b/tensorflow/python/saved_model/signature_constants.py @@ -19,6 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# Key in the signature def map for `default` serving signatures. The default +# signature is used in inference requests where a specific signature was not +# specified. +DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default" + ################################################################################ # Classification API constants. |