aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Christopher Olston <olston@google.com>2016-11-30 09:21:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-30 09:43:07 -0800
commitab8fb0bb40e3ec729db00f552f0bad3c8a118f87 (patch)
tree16f73da3b6b20927c033de7bd1e2cc1bfc3b4b38
parent5648e4e7db6db59c4a3ec59f2321f204deb24b0f (diff)
Have ConvertSignaturesToSignatureDef() handle the case in which there are both named and default signatures.
Also, add some ASSERTs to avoid segfaulting when a test fails. Change: 140612556
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc23
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim_test.cc42
2 files changed, 57 insertions, 8 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 47a0935472..1ce2753c57 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -127,10 +127,16 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
AddOutputToSignatureDef(map_entry.second.tensor_name(), map_entry.first,
&signature_def);
}
- // Add the `default` key to the signature def map of the meta graph def and
- // map it to the constructed signature def.
- (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- signature_def;
+ // Add the constructed signature def to the signature def map of the meta
+ // graph def. Use the default key if it isn't already in use.
+ const bool already_has_default_signature =
+ meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
+ meta_graph_def->signature_def().end();
+ const string signature_def_key =
+ already_has_default_signature
+ ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")
+ : kDefaultServingSignatureDefKey;
+ (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def;
return Status::OK();
}
@@ -138,9 +144,12 @@ Status ConvertSignaturesToSignatureDef(MetaGraphDef* meta_graph_def) {
Signatures signatures;
GetSignatures(*meta_graph_def, &signatures);
if (signatures.has_default_signature()) {
- return ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def);
- } else if (!signatures.named_signatures().empty()) {
- return ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def);
+ TF_RETURN_IF_ERROR(
+ ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def));
+ }
+ if (!signatures.named_signatures().empty()) {
+ TF_RETURN_IF_ERROR(
+ ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def));
}
return Status::OK();
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 81e636fb2e..a8dca12195 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -74,13 +74,18 @@ void LoadAndValidateSavedModelBundle(const string& export_dir,
const auto& signature_def_map = meta_graph_def.signature_def();
const auto& regression_entry = signature_def_map.find(signature_def_key);
+ ASSERT_FALSE(regression_entry == signature_def_map.end());
SignatureDef regression_signature_def = regression_entry->second;
EXPECT_EQ(1, regression_signature_def.inputs_size());
+ ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
+ regression_signature_def.inputs().end());
TensorInfo input_tensor_info =
regression_signature_def.inputs().find(kRegressInputs)->second;
EXPECT_EQ(1, regression_signature_def.outputs_size());
+ ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
+ regression_signature_def.outputs().end());
TensorInfo output_tensor_info =
regression_signature_def.outputs().find(kRegressOutputs)->second;
ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
@@ -260,9 +265,14 @@ TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
EXPECT_EQ(1, meta_graph_def.signature_def_size());
const auto actual_signature_def =
meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
+ ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
+ ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
+ actual_signature_def->second.inputs().end());
EXPECT_EQ(
"foo-input",
actual_signature_def->second.inputs().find("foo-input")->second.name());
+ ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
+ actual_signature_def->second.outputs().end());
EXPECT_EQ(
"foo-output",
actual_signature_def->second.outputs().find("foo-output")->second.name());
@@ -317,10 +327,40 @@ TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
// Checks a basic up conversion for half plus two for SessionBundle.
TEST(BundleShimTest, BasicExportSessionBundle) {
+ const std::unordered_set<string> tags = {"tag"};
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
- LoadAndValidateSavedModelBundle(session_bundle_export_dir, {"tag"},
+ LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
kDefaultServingSignatureDefKey);
+
+ // Verify that the named signature is also present.
+ SessionOptions session_options;
+ RunOptions run_options;
+ SavedModelBundle saved_model_bundle;
+ TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
+ session_bundle_export_dir,
+ tags, &saved_model_bundle));
+ const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
+ const auto& signature_def_map = meta_graph_def.signature_def();
+ bool found_named_signature = false;
+ for (const auto& entry : signature_def_map) {
+ const string& key = entry.first;
+ const SignatureDef& signature_def = entry.second;
+
+ // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
+ if (key == kDefaultServingSignatureDefKey) {
+ continue;
+ }
+ found_named_signature = true;
+
+ EXPECT_EQ(1, signature_def.inputs_size());
+ EXPECT_FALSE(signature_def.inputs().find("x") ==
+ signature_def.inputs().end());
+ EXPECT_EQ(1, signature_def.outputs_size());
+ EXPECT_FALSE(signature_def.outputs().find("y") ==
+ signature_def.outputs().end());
+ }
+ EXPECT_TRUE(found_named_signature);
}
// Checks a basic load for half plus two for SavedModelBundle.