aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-02 16:25:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-02 16:28:52 -0800
commit4df167ac55346357afd612d15674c7556e21ab00 (patch)
treef9d048371157604c02538ec090dbc73ab90ef4c2 /tensorflow/core/grappler/utils.cc
parent284dac189dcae46c77f1ec70055b13e69c31e4c0 (diff)
Loop optimizer: Convert StackPush nodes to Identity instead of eliminating them completely.
Move loop optimizer to run before dependency optimizer so identity nodes will be pruned. PiperOrigin-RevId: 187685669
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index a611a93086..eb1f882ff1 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -398,12 +398,12 @@ Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs,
void SimpleGraphView::DepthFirstSearch(
const std::unordered_set<string>& op_types_to_traverse, int node_idx,
std::set<int>* nodes_found) const {
- const NodeDef& node = graph_->node(node_idx);
- if (op_types_to_traverse.find(node.op()) == op_types_to_traverse.end()) {
- nodes_found->insert(node_idx);
+ if (nodes_found->find(node_idx) != nodes_found->end()) {
return;
}
- if (nodes_found->find(node_idx) != nodes_found->end()) {
+ nodes_found->insert(node_idx);
+ const string& op_type = graph_->node(node_idx).op();
+ if (op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
return;
}
for (auto output_idx : this->outputs(node_idx)) {