aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc21
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fb2fe6883..4fed88d536 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {