aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 11:18:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 11:47:06 -0700
commit965e3b0ca01ed7cc951131454b38ab638ff44fbf (patch)
tree7a7dad14dc6a68b1cce6b33c3d37f569a79f63b4 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 (diff)
Extend hoisting monotonic functions out of min/max reductions to all monotonic unary functions.
Add the ability to flip Max <-> Min if the function is non-increasing, e.g. Max(Neg(x)) => Neg(Min(x)). PiperOrigin-RevId: 211490436
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fed88d536..65947ddce5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2706,8 +2706,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
// 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
+ bool is_non_decreasing = false;
if (!IsInPreserveSet(*inner_function) &&
- IsElementWiseMonotonic(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
@@ -2719,7 +2720,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
UpdateConsumers(reduction_node, inner_function->name());
ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
reduction_node->name());
-
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
AddToOptimizationQueue(reduction_node);
AddToOptimizationQueue(inner_function);
AddToOptimizationQueue(inner_input);