diff options
author | 2016-02-29 16:20:56 -0800 | |
---|---|---|
committer | 2016-02-29 16:29:12 -0800 | |
commit | c1d385d01924cf02e81adabbdec575c100b14f59 (patch) | |
tree | 468afd64c246dd5532f7912f799cfa587240ccf4 /tensorflow/core/framework/graph_def_util.cc | |
parent | d4b17f62fa09a10869d0f5d70e5afde833b23f0f (diff) |
Expose tf.contrib.util.stripped_op_list_for_graph
C++ and Python use two different op registries, and in rare cases they can
actually be different. Thus, we need both functions available.
Also fix both Python and C++ to handle arbitrarily nested functions.
Change: 115918836
Diffstat (limited to 'tensorflow/core/framework/graph_def_util.cc')
-rw-r--r-- | tensorflow/core/framework/graph_def_util.cc | 49 |
1 files changed, 9 insertions, 40 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 71281cf7b2..2eb23bdf9c 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -16,8 +16,6 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include <set> -#include <unordered_map> -#include <unordered_set> #include <vector> #include "tensorflow/core/framework/node_def_util.h" @@ -123,49 +121,20 @@ Status StrippedOpListForGraph(const GraphDef& graph_def, OpList* stripped_op_list) { stripped_op_list->clear_op(); - // Map function names to definitions. - std::unordered_map<string, const FunctionDef*> name_to_function; - for (const auto& function : graph_def.library().function()) { - name_to_function.insert( - std::make_pair(function.signature().name(), &function)); - } - - // Collect the sorted list of op names. Since functions can reference - // functions, we need a recursive traversal. - std::set<string> used_ops; // Includes both primitive ops and functions - std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops - // Collect the logic to mark an op in a lambda; it'll be used twice below. - const auto mark_op_as_used = [&used_ops, &functions_to_process, - &name_to_function](const string& op) { - if (used_ops.insert(op).second) { - // If it's a function, we'll need to process further - const auto it = name_to_function.find(op); - if (it != name_to_function.end()) { - functions_to_process.push_back(it->second); - } - } - }; + // Collect the sorted list of op names + std::set<string> used_ops; for (const auto& node : graph_def.node()) { - mark_op_as_used(node.op()); - } - while (!functions_to_process.empty()) { - const FunctionDef* fun = functions_to_process.back(); - functions_to_process.pop_back(); - for (const auto& node : fun->node()) { - mark_op_as_used(node.op()); - } + used_ops.insert(node.op()); } - // Build the stripped op list in sorted order, ignoring functions. + // Build the stripped op list in sorted order. Status status; for (const string& op_name : used_ops) { - if (name_to_function.find(op_name) == name_to_function.end()) { - const OpDef* op = op_registry.LookUp(op_name, &status); - if (!op) return status; - OpDef* stripped_op = stripped_op_list->add_op(); - stripped_op->CopyFrom(*op); - RemoveDescriptionsFromOpDef(stripped_op); - } + const OpDef* op = op_registry.LookUp(op_name, &status); + if (!op) return status; + OpDef* stripped_op = stripped_op_list->add_op(); + stripped_op->CopyFrom(*op); + RemoveDescriptionsFromOpDef(stripped_op); } return Status::OK(); } |