diff options
author | 2016-05-03 20:16:59 -0800 | |
---|---|---|
committer | 2016-05-03 21:22:00 -0700 | |
commit | df9e8046eb44b2ad7ceb14e2dd6a49741c0e12e0 (patch) | |
tree | 17ca815f5d2790eb13b7c5a03646b03842a49ebc /tensorflow/core/framework/graph_def_util.cc | |
parent | 47916d41d12dff93b57544c1a5d09b7dcae93d84 (diff) |
Extract OpsUsedByGraph() from StrippedOpListForGraph().
Change: 121447237
Diffstat (limited to 'tensorflow/core/framework/graph_def_util.cc')
-rw-r--r-- | tensorflow/core/framework/graph_def_util.cc | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 047847263c..9087b38fd9 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -120,11 +120,8 @@ Status RemoveNewDefaultAttrsFromGraphDef( return s; } -Status StrippedOpListForGraph(const GraphDef& graph_def, - const OpRegistryInterface& op_registry, - OpList* stripped_op_list) { - stripped_op_list->clear_op(); - +void OpsUsedByGraph(const GraphDef& graph_def, + std::set<string>* ops_used_in_graph) { // Map function names to definitions. std::unordered_map<string, const FunctionDef*> name_to_function; for (const auto& function : graph_def.library().function()) { @@ -158,17 +155,32 @@ Status StrippedOpListForGraph(const GraphDef& graph_def, } } - // Build the stripped op list in sorted order, ignoring functions. - Status status; + // Filter out function names to produce output. + // TODO(josh11b): Change the above code to produce this directly. + ops_used_in_graph->clear(); for (const string& op_name : used_ops) { if (name_to_function.find(op_name) == name_to_function.end()) { - const OpDef* op = op_registry.LookUp(op_name, &status); - if (!op) return status; - OpDef* stripped_op = stripped_op_list->add_op(); - stripped_op->CopyFrom(*op); - RemoveDescriptionsFromOpDef(stripped_op); + ops_used_in_graph->insert(op_name); } } +} + +Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list) { + std::set<string> used_ops; + OpsUsedByGraph(graph_def, &used_ops); + + // Build the stripped op list in sorted order, ignoring functions. + Status status; + stripped_op_list->clear_op(); + for (const string& op_name : used_ops) { + const OpDef* op = op_registry.LookUp(op_name, &status); + if (!op) return status; + OpDef* stripped_op = stripped_op_list->add_op(); + stripped_op->CopyFrom(*op); + RemoveDescriptionsFromOpDef(stripped_op); + } return Status::OK(); } |