diff options
author | 2016-12-22 15:21:44 -0800 | |
---|---|---|
committer | 2016-12-22 15:45:02 -0800 | |
commit | 1e5bd8cdd62033d1f7ea928fcbec521bb48bb1f5 (patch) | |
tree | 964182a523bfb596bcab306f8f2bd7de8087d246 | |
parent | aa112f20647acd2298e3752cea7e370d5841b032 (diff) |
Introduce a shim function wrapping checks for possible SessionBundle or
SavedModel export formats.
Change: 142803996
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim.cc | 5 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim.h | 8 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/bundle_shim_test.cc | 12 |
3 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index d9ee56fcc9..4a3499bcc6 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -292,5 +292,10 @@ Status LoadSessionBundleOrSavedModelBundle( "export location"); } +bool MaybeSessionBundleOrSavedModelDirectory(const string& export_dir) { + return IsPossibleExportDirectory(export_dir) || + MaybeSavedModelDirectory(export_dir); +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h index 37c242c6ea..4a614542e0 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.h +++ b/tensorflow/contrib/session_bundle/bundle_shim.h @@ -59,6 +59,14 @@ Status LoadSessionBundleOrSavedModelBundle( const string& export_dir, const std::unordered_set<string>& tags, SavedModelBundle* bundle); +// Checks whether the provided directory could contain a SessionBundle or a +// SavedModel. Note that the method does not load any data by itself. If the +// method returns `false`, the export directory definitely does not contain a +// SessionBundle or SavedModel. If the method returns `true`, the export +// directory may contain a SessionBundle or a SavedModel but provides no +// guarantee that it can be loaded. +bool MaybeSessionBundleOrSavedModelDirectory(const string& export_dir); + } // namespace serving } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_ diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index fb367beb0f..22b3151527 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -495,6 +495,18 @@ TEST(BundleShimTest, LoadSessionBundleError) { .ok()); } +TEST(BundleShimTest, MaybeSessionBundleOrSavedModelDirectory) { + const string saved_model_export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath); + const string session_bundle_export_dir = + test_util::TestSrcDirPath(kSessionBundlePath); + const string invalid_export_dir = testing::TensorFlowSrcRoot(); + EXPECT_TRUE(MaybeSessionBundleOrSavedModelDirectory(saved_model_export_dir)); + EXPECT_TRUE( + MaybeSessionBundleOrSavedModelDirectory(session_bundle_export_dir)); + EXPECT_FALSE(MaybeSessionBundleOrSavedModelDirectory(invalid_export_dir)); +} + } // namespace } // namespace internal } // namespace serving |