aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 18:25:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 18:29:29 -0700
commit58745044c7224f894a51dcf25e26412fa23fc444 (patch)
tree04ad4c60a6ccc670212f3d5e1870d3f0cb82196f /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parentab96371eaea4cc5b2f9c431eec455a1cf4be7c1c (diff)
Fix bug in hoisting monotonic functions out of reductions: Do not change the value of nodes in the preserve set, e.g. fetch nodes.
I simplified the rewiring logic a tad. PiperOrigin-RevId: 211017989
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc21
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fb2fe6883..4fed88d536 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {