aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/saved_model
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-03-20 10:27:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-20 11:45:50 -0700
commit4001bd52f4655e749714c1f9ccddefcdf5e6d855 (patch)
tree786852d919ac5b072ae087a6f225e890af4a6f68 /tensorflow/cc/saved_model
parent92eb06d2e4d98bcd58c8f4d7c68de0d3c637e181 (diff)
Add initial support for main-op in SavedModel C++.
Change: 150651883
Diffstat (limited to 'tensorflow/cc/saved_model')
-rw-r--r--tensorflow/cc/saved_model/BUILD1
-rw-r--r--tensorflow/cc/saved_model/constants.h3
-rw-r--r--tensorflow/cc/saved_model/loader.cc45
-rw-r--r--tensorflow/cc/saved_model/loader_test.cc14
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt1
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pbbin0 -> 10095 bytes
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001bin0 -> 12 bytes
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.indexbin0 -> 151 bytes
8 files changed, 59 insertions, 5 deletions
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 36fec7a2f2..b402570757 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -66,6 +66,7 @@ filegroup(
name = "saved_model_half_plus_two",
srcs = glob([
"testdata/half_plus_two_pbtxt/**",
+ "testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
]),
)
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index 7f2d560978..94a3b3cf46 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -33,6 +33,9 @@ constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
/// SavedModel legacy init op key.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
+/// SavedModel main op key.
+constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
+
/// Directory in which to save the SavedModel variables.
constexpr char kSavedModelVariablesDirectory[] = "variables";
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index a7843ddb1d..b144bfc33e 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -106,6 +106,37 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir,
}
}
+bool HasMainOp(const MetaGraphDef& meta_graph_def) {
+ const auto& collection_def_map = meta_graph_def.collection_def();
+ if (collection_def_map.find(kSavedModelMainOpKey) !=
+ collection_def_map.end()) {
+ return true;
+ }
+ return false;
+}
+
+Status RunMainOp(const RunOptions& run_options, const string& export_dir,
+ const MetaGraphDef& meta_graph_def,
+ const std::vector<AssetFileDef>& asset_file_defs,
+ Session* session) {
+ LOG(INFO) << "Running MainOp on SavedModel bundle.";
+ const auto& collection_def_map = meta_graph_def.collection_def();
+ const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey);
+ if (main_op_it != collection_def_map.end()) {
+ if (main_op_it->second.node_list().value_size() != 1) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected exactly one main op in : ", export_dir));
+ }
+ std::vector<std::pair<string, Tensor>> inputs;
+ AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
+ RunMetadata run_metadata;
+ const StringPiece main_op_name = main_op_it->second.node_list().value(0);
+ return session->Run(run_options, inputs, {}, {main_op_name.ToString()},
+ nullptr /* outputs */, &run_metadata);
+ }
+ return Status::OK();
+}
+
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@@ -211,11 +242,15 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get()));
- // TODO(sukritiramesh): Add support for a single main op to run upon load,
- // which will supersede the legacy_init_op and separate RunRestore.
- TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
- bundle->meta_graph_def, asset_file_defs,
- bundle->session.get()));
+ if (HasMainOp(bundle->meta_graph_def)) {
+ TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
+ bundle->meta_graph_def, asset_file_defs,
+ bundle->session.get()));
+ } else {
+ TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
+ bundle->meta_graph_def, asset_file_defs,
+ bundle->session.get()));
+ }
return Status::OK();
}
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 2a8a7c5bff..cef29e7b07 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -31,6 +31,8 @@ namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
+constexpr char kTestDataMainOp[] =
+ "cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
@@ -165,6 +167,18 @@ TEST_F(LoaderTest, PbtxtFormat) {
CheckSavedModelBundle(export_dir, bundle);
}
+TEST_F(LoaderTest, MainOpFormat) {
+ SavedModelBundle bundle;
+ SessionOptions session_options;
+ RunOptions run_options;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp);
+ TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
+ {kSavedModelTagServe}, &bundle));
+ CheckSavedModelBundle(export_dir, bundle);
+}
+
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundle bundle;
RunOptions run_options;
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt
new file mode 100644
index 0000000000..f9ff036688
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt
@@ -0,0 +1 @@
+asset-file-contents \ No newline at end of file
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb
new file mode 100644
index 0000000000..cf6234821a
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000..15b75d6ef6
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index
new file mode 100644
index 0000000000..7ec9fb4fe2
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index
Binary files differ