diff options
author | Abhijit Karmarkar <awk@google.com> | 2018-09-20 22:18:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 22:22:43 -0700 |
commit | 23552a8b2f2a92a31710b9339e6ade514ac25996 (patch) | |
tree | e3669169491cab7ed014dc244ea2a0c214914184 /tensorflow/contrib/session_bundle | |
parent | f10b00558de87020554c9c0512537dab96dba918 (diff) |
Return model format from LoadSessionBundleOrSavedModelBundle(),
allowing callers to know if we up-converted a SessionBundle to
SavedModel format.
PiperOrigin-RevId: 213937542
Diffstat (limited to 'tensorflow/contrib/session_bundle')
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim.h | 6 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim_test.cc | 14 |
3 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index 4fc36d85ed..c669ced997 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle( const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set<string>& saved_model_tags, - SavedModelBundle* saved_model_bundle) { + SavedModelBundle* saved_model_bundle, bool* is_session_bundle) { + if (is_session_bundle != nullptr) { + *is_session_bundle = false; + } if (MaybeSavedModelDirectory(export_dir)) { LOG(INFO) << "Attempting to load native SavedModelBundle in bundle-shim from: " << export_dir; + return LoadSavedModel(session_options, run_options, export_dir, saved_model_tags, saved_model_bundle); } else if (IsPossibleExportDirectory(export_dir)) { @@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle( LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle " "in bundle-shim from: " << export_dir; + if (is_session_bundle != nullptr) { + *is_session_bundle = true; + } return LoadSavedModelFromLegacySessionBundlePath( session_options, run_options, export_dir, saved_model_bundle); } diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h index 4628b6ab1b..7f0f9958d7 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.h +++ b/tensorflow/contrib/session_bundle/bundle_shim.h @@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle( } // namespace internal // Loads a SavedModel from either a session-bundle path or a SavedModel bundle -// path. +// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff +// SavedModel was up-converted and loaded from a SessionBundle. +// `is_session_bundle` value should not be used if error is returned. Status LoadSessionBundleOrSavedModelBundle( const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set<string>& tags, - SavedModelBundle* bundle); + SavedModelBundle* bundle, bool* is_session_bundle = nullptr); } // namespace serving } // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index 9a1dd9303f..815beb73a0 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle, void LoadAndValidateSavedModelBundle(const string& export_dir, const std::unordered_set<string>& tags, - const string& signature_def_key) { + const string& signature_def_key, + bool expect_session_bundle) { SessionOptions session_options; RunOptions run_options; SavedModelBundle saved_model_bundle; + bool is_session_bundle = false; TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle( - session_options, run_options, export_dir, tags, &saved_model_bundle)); + session_options, run_options, export_dir, tags, &saved_model_bundle, + &is_session_bundle)); + EXPECT_EQ(expect_session_bundle, is_session_bundle); const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def; const auto& signature_def_map = meta_graph_def.signature_def(); @@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) { const string session_bundle_export_dir = test_util::TestSrcDirPath(kSessionBundlePath); LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags, - kDefaultServingSignatureDefKey); + kDefaultServingSignatureDefKey, + /*expect_session_bundle=*/true); // Verify that the named signature is also present. SessionOptions session_options; @@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) { const string saved_model_bundle_export_dir = io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath); LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir, - {kSavedModelTagServe}, "regress_x_to_y"); + {kSavedModelTagServe}, "regress_x_to_y", + /*expect_session_bundle=*/false); } // Checks a basic load fails with an invalid export path. |