diff options
Diffstat (limited to 'tensorflow/cc/saved_model/loader.cc')
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 70 |
1 files changed, 6 insertions, 64 deletions
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index faa1e378d0..07807ed2f3 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,8 +18,10 @@ limitations under the License. #include <unordered_set> #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" @@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New( constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; -Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { - const string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - return ReadBinaryProto(Env::Default(), saved_model_pb_path, - saved_model_proto); - } - const string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - return ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - } - return Status(error::Code::NOT_FOUND, - "Could not find SavedModel .pb or .pbtxt at supplied export " - "directory path: " + - export_dir); -} - -string GetTagsAsString(const std::unordered_set<string>& tags) { - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); - return tags_as_string; -} - -Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, - const std::unordered_set<string>& tags, - MetaGraphDef* meta_graph_def_to_load) { - for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) { - // Get tags from the meta_graph_def. - std::unordered_set<string> graph_tags; - for (const string& tag : meta_graph_def.meta_info_def().tags()) { - graph_tags.insert(tag); - } - // Match with the set of tags provided. - if (graph_tags == tags) { - *meta_graph_def_to_load = meta_graph_def; - return Status::OK(); - } - } - return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags: " + - GetTagsAsString(tags) + - ". To inspect available tag-sets in the SavedModel, please " - "use the SavedModel CLI: `saved_model_cli`"); -} - Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr<Session>* session) { @@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set<string>& tags, SavedModelBundle* const bundle) { - if (!MaybeSavedModelDirectory(export_dir)) { - return Status(error::Code::NOT_FOUND, - "SavedModel not found in export directory: " + export_dir); - } - LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) - << "; from: " << export_dir; - - SavedModel saved_model_proto; - TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); - - TF_RETURN_IF_ERROR( - FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); + TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &bundle->meta_graph_def)); TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( bundle->meta_graph_def, session_options, &bundle->session)); @@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) - << "; Status: " << status_str << ". Took " + LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") + << " }; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; |