aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle
diff options
context:
space:
mode:
authorGravatar Abhijit Karmarkar <awk@google.com>2018-09-20 22:18:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 22:22:43 -0700
commit23552a8b2f2a92a31710b9339e6ade514ac25996 (patch)
treee3669169491cab7ed014dc244ea2a0c214914184 /tensorflow/contrib/session_bundle
parentf10b00558de87020554c9c0512537dab96dba918 (diff)
Return model format from LoadSessionBundleOrSavedModelBundle(),
allowing callers to know if we up-converted a SessionBundle to SavedModel format. PiperOrigin-RevId: 213937542
Diffstat (limited to 'tensorflow/contrib/session_bundle')
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc9
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.h6
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim_test.cc14
3 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 4fc36d85ed..c669ced997 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir,
const std::unordered_set<string>& saved_model_tags,
- SavedModelBundle* saved_model_bundle) {
+ SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = false;
+ }
if (MaybeSavedModelDirectory(export_dir)) {
LOG(INFO)
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
<< export_dir;
+
return LoadSavedModel(session_options, run_options, export_dir,
saved_model_tags, saved_model_bundle);
} else if (IsPossibleExportDirectory(export_dir)) {
@@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle(
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
"in bundle-shim from: "
<< export_dir;
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = true;
+ }
return LoadSavedModelFromLegacySessionBundlePath(
session_options, run_options, export_dir, saved_model_bundle);
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 4628b6ab1b..7f0f9958d7 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle(
} // namespace internal
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
-// path.
+// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
+// SavedModel was up-converted and loaded from a SessionBundle.
+// `is_session_bundle` value should not be used if error is returned.
Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir, const std::unordered_set<string>& tags,
- SavedModelBundle* bundle);
+ SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 9a1dd9303f..815beb73a0 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
void LoadAndValidateSavedModelBundle(const string& export_dir,
const std::unordered_set<string>& tags,
- const string& signature_def_key) {
+ const string& signature_def_key,
+ bool expect_session_bundle) {
SessionOptions session_options;
RunOptions run_options;
SavedModelBundle saved_model_bundle;
+ bool is_session_bundle = false;
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
- session_options, run_options, export_dir, tags, &saved_model_bundle));
+ session_options, run_options, export_dir, tags, &saved_model_bundle,
+ &is_session_bundle));
+ EXPECT_EQ(expect_session_bundle, is_session_bundle);
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
@@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
- kDefaultServingSignatureDefKey);
+ kDefaultServingSignatureDefKey,
+ /*expect_session_bundle=*/true);
// Verify that the named signature is also present.
SessionOptions session_options;
@@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) {
const string saved_model_bundle_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
- {kSavedModelTagServe}, "regress_x_to_y");
+ {kSavedModelTagServe}, "regress_x_to_y",
+ /*expect_session_bundle=*/false);
}
// Checks a basic load fails with an invalid export path.