diff options
author | Sukriti Ramesh <sukritiramesh@google.com> | 2017-03-20 10:27:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 11:45:50 -0700 |
commit | 4001bd52f4655e749714c1f9ccddefcdf5e6d855 (patch) | |
tree | 786852d919ac5b072ae087a6f225e890af4a6f68 /tensorflow/cc/saved_model | |
parent | 92eb06d2e4d98bcd58c8f4d7c68de0d3c637e181 (diff) |
Add initial support for main-op in SavedModel C++.
Change: 150651883
Diffstat (limited to 'tensorflow/cc/saved_model')
-rw-r--r-- | tensorflow/cc/saved_model/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/constants.h | 3 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 45 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader_test.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt | 1 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb | bin | 0 -> 10095 bytes | |||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 | bin | 0 -> 12 bytes | |||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index | bin | 0 -> 151 bytes |
8 files changed, 59 insertions, 5 deletions
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 36fec7a2f2..b402570757 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -66,6 +66,7 @@ filegroup( name = "saved_model_half_plus_two", srcs = glob([ "testdata/half_plus_two_pbtxt/**", + "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 7f2d560978..94a3b3cf46 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -33,6 +33,9 @@ constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; /// SavedModel legacy init op key. constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; +/// SavedModel main op key. +constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; + /// Directory in which to save the SavedModel variables. constexpr char kSavedModelVariablesDirectory[] = "variables"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a7843ddb1d..b144bfc33e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -106,6 +106,37 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +bool HasMainOp(const MetaGraphDef& meta_graph_def) { + const auto& collection_def_map = meta_graph_def.collection_def(); + if (collection_def_map.find(kSavedModelMainOpKey) != + collection_def_map.end()) { + return true; + } + return false; +} + +Status RunMainOp(const RunOptions& run_options, const string& export_dir, + const MetaGraphDef& meta_graph_def, + const std::vector<AssetFileDef>& asset_file_defs, + Session* session) { + LOG(INFO) << "Running MainOp on SavedModel bundle."; + const auto& collection_def_map = meta_graph_def.collection_def(); + const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); + if (main_op_it != collection_def_map.end()) { + if (main_op_it->second.node_list().value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one main op in : ", export_dir)); + } + std::vector<std::pair<string, Tensor>> inputs; + AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); + RunMetadata run_metadata; + const StringPiece main_op_name = main_op_it->second.node_list().value(0); + return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata); + } + return Status::OK(); +} + Status RunRestore(const RunOptions& run_options, const string& export_dir, const StringPiece restore_op_name, const StringPiece variable_filename_const_op_name, @@ -211,11 +242,15 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); - // TODO(sukritiramesh): Add support for a single main op to run upon load, - // which will supersede the legacy_init_op and separate RunRestore. - TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + if (HasMainOp(bundle->meta_graph_def)) { + TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, + bundle->meta_graph_def, asset_file_defs, + bundle->session.get())); + } else { + TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, + bundle->meta_graph_def, asset_file_defs, + bundle->session.get())); + } return Status::OK(); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 2a8a7c5bff..cef29e7b07 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -31,6 +31,8 @@ namespace { constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataMainOp[] = + "cc/saved_model/testdata/half_plus_two_main_op/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; @@ -165,6 +167,18 @@ TEST_F(LoaderTest, PbtxtFormat) { CheckSavedModelBundle(export_dir, bundle); } +TEST_F(LoaderTest, MainOpFormat) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + TEST_F(LoaderTest, InvalidExportPath) { SavedModelBundle bundle; RunOptions run_options; diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt new file mode 100644 index 0000000000..f9ff036688 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents
\ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb Binary files differnew file mode 100644 index 0000000000..cf6234821a --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 0000000000..15b75d6ef6 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index Binary files differnew file mode 100644 index 0000000000..7ec9fb4fe2 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index |