diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fb2fe6883..4fed88d536 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { NodeDef* inner_function; TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); // Optimize only if: + // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. - if (IsElementWiseMonotonic(*inner_function) && - (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + if (!IsInPreserveSet(*inner_function) && + IsElementWiseMonotonic(*inner_function) && + ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. NodeDef* inner_input; TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); - inner_function->set_input(0, reduction_node->name()); - UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); reduction_node->set_input(0, inner_input->name()); - UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + ctx().node_map->UpdateInput(reduction_node->name(), + inner_function->name(), inner_input->name()); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumers(reduction_node, inner_function->name()); + ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), + reduction_node->name()); + + AddToOptimizationQueue(reduction_node); + AddToOptimizationQueue(inner_function); + AddToOptimizationQueue(inner_input); } return Status::OK(); } - void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + void UpdateConsumers(NodeDef* node, const string& new_input) { const string& node_name = node->name(); const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { |