diff options
author | 2017-08-15 13:31:50 -0700 | |
---|---|---|
committer | 2017-08-15 13:37:32 -0700 | |
commit | d4bc0ef0647823a207bb19cc6ba03bdafc56c3e4 (patch) | |
tree | bd3b1e434ec95f03a77117386d760a947744a50a /tensorflow/core/grappler/grappler_item_builder.cc | |
parent | fed992fb2ad6c39a7241b57093e7918f433c3e89 (diff) |
Add save and restore op in grappler item;
PiperOrigin-RevId: 165350681
Diffstat (limited to 'tensorflow/core/grappler/grappler_item_builder.cc')
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 1002e89417..6136651410 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -331,6 +332,32 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( } } + if (meta_graph.collection_def().count("savers") > 0) { + const CollectionDef& savers = meta_graph.collection_def().at("savers"); + for (const auto& raw : savers.bytes_list().value()) { + SaverDef saver; + // Skip bad savers since we don't need saves/restores to be able to run a + // graph. + if (!saver.ParseFromString(raw)) { + continue; + } + if (saver.filename_tensor_name().empty()) { + continue; + } + new_item->save_op = saver.save_tensor_name(); + new_item->restore_op = saver.restore_op_name(); + new_item->save_restore_loc_tensor = saver.filename_tensor_name(); + // Only use the first saver since it's not clear what to do if there's + // more than one. + break; + } + } else { + const SaverDef& saver = meta_graph.saver_def(); + new_item->save_op = saver.save_tensor_name(); + new_item->restore_op = saver.restore_op_name(); + new_item->save_restore_loc_tensor = saver.filename_tensor_name(); + } + // Optimize the graph (function inlining, l1 optimizations, etc). Status optimize_status = OptimizeGraph(new_item->graph, &new_item->graph, cfg); |