aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/graph_def_util.cc
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-02-29 16:20:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-29 16:29:12 -0800
commitc1d385d01924cf02e81adabbdec575c100b14f59 (patch)
tree468afd64c246dd5532f7912f799cfa587240ccf4 /tensorflow/core/framework/graph_def_util.cc
parentd4b17f62fa09a10869d0f5d70e5afde833b23f0f (diff)
Expose tf.contrib.util.stripped_op_list_for_graph
C++ and Python use two different op registries, and in rare cases they can actually be different. Thus, we need both functions available. Also fix both Python and C++ to handle arbitrarily nested functions. Change: 115918836
Diffstat (limited to 'tensorflow/core/framework/graph_def_util.cc')
-rw-r--r--tensorflow/core/framework/graph_def_util.cc49
1 files changed, 9 insertions, 40 deletions
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index 71281cf7b2..2eb23bdf9c 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -16,8 +16,6 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include <set>
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/node_def_util.h"
@@ -123,49 +121,20 @@ Status StrippedOpListForGraph(const GraphDef& graph_def,
OpList* stripped_op_list) {
stripped_op_list->clear_op();
- // Map function names to definitions.
- std::unordered_map<string, const FunctionDef*> name_to_function;
- for (const auto& function : graph_def.library().function()) {
- name_to_function.insert(
- std::make_pair(function.signature().name(), &function));
- }
-
- // Collect the sorted list of op names. Since functions can reference
- // functions, we need a recursive traversal.
- std::set<string> used_ops; // Includes both primitive ops and functions
- std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops
- // Collect the logic to mark an op in a lambda; it'll be used twice below.
- const auto mark_op_as_used = [&used_ops, &functions_to_process,
- &name_to_function](const string& op) {
- if (used_ops.insert(op).second) {
- // If it's a function, we'll need to process further
- const auto it = name_to_function.find(op);
- if (it != name_to_function.end()) {
- functions_to_process.push_back(it->second);
- }
- }
- };
+ // Collect the sorted list of op names
+ std::set<string> used_ops;
for (const auto& node : graph_def.node()) {
- mark_op_as_used(node.op());
- }
- while (!functions_to_process.empty()) {
- const FunctionDef* fun = functions_to_process.back();
- functions_to_process.pop_back();
- for (const auto& node : fun->node()) {
- mark_op_as_used(node.op());
- }
+ used_ops.insert(node.op());
}
- // Build the stripped op list in sorted order, ignoring functions.
+ // Build the stripped op list in sorted order.
Status status;
for (const string& op_name : used_ops) {
- if (name_to_function.find(op_name) == name_to_function.end()) {
- const OpDef* op = op_registry.LookUp(op_name, &status);
- if (!op) return status;
- OpDef* stripped_op = stripped_op_list->add_op();
- stripped_op->CopyFrom(*op);
- RemoveDescriptionsFromOpDef(stripped_op);
- }
+ const OpDef* op = op_registry.LookUp(op_name, &status);
+ if (!op) return status;
+ OpDef* stripped_op = stripped_op_list->add_op();
+ stripped_op->CopyFrom(*op);
+ RemoveDescriptionsFromOpDef(stripped_op);
}
return Status::OK();
}