diff options
author | 2018-03-02 16:25:21 -0800 | |
---|---|---|
committer | 2018-03-02 16:28:52 -0800 | |
commit | 4df167ac55346357afd612d15674c7556e21ab00 (patch) | |
tree | f9d048371157604c02538ec090dbc73ab90ef4c2 /tensorflow/core/grappler/utils.cc | |
parent | 284dac189dcae46c77f1ec70055b13e69c31e4c0 (diff) |
Loop optimizer: Convert StackPush nodes to Identity instead of eliminating them completely.
Move loop optimizer to run before dependency optimizer so identity nodes will be pruned.
PiperOrigin-RevId: 187685669
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index a611a93086..eb1f882ff1 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -398,12 +398,12 @@ Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs, void SimpleGraphView::DepthFirstSearch( const std::unordered_set<string>& op_types_to_traverse, int node_idx, std::set<int>* nodes_found) const { - const NodeDef& node = graph_->node(node_idx); - if (op_types_to_traverse.find(node.op()) == op_types_to_traverse.end()) { - nodes_found->insert(node_idx); + if (nodes_found->find(node_idx) != nodes_found->end()) { return; } - if (nodes_found->find(node_idx) != nodes_found->end()) { + nodes_found->insert(node_idx); + const string& op_type = graph_->node(node_idx).op(); + if (op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) { return; } for (auto output_idx : this->outputs(node_idx)) { |