diff options
author | 2018-08-30 18:25:30 -0700 | |
---|---|---|
committer | 2018-08-30 18:29:29 -0700 | |
commit | 58745044c7224f894a51dcf25e26412fa23fc444 (patch) | |
tree | 04ad4c60a6ccc670212f3d5e1870d3f0cb82196f /tensorflow/core/grappler | |
parent | ab96371eaea4cc5b2f9c431eec455a1cf4be7c1c (diff) |
Fix bug in hoisting monotonic functions out of reductions: Do not change the value of nodes in the preserve set, e.g. fetch nodes.
I simplified the rewiring logic a tad.
PiperOrigin-RevId: 211017989
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 24 |
2 files changed, 39 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fb2fe6883..4fed88d536 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { NodeDef* inner_function; TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); // Optimize only if: + // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. - if (IsElementWiseMonotonic(*inner_function) && - (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + if (!IsInPreserveSet(*inner_function) && + IsElementWiseMonotonic(*inner_function) && + ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. NodeDef* inner_input; TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); - inner_function->set_input(0, reduction_node->name()); - UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); reduction_node->set_input(0, inner_input->name()); - UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + ctx().node_map->UpdateInput(reduction_node->name(), + inner_function->name(), inner_input->name()); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumers(reduction_node, inner_function->name()); + ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), + reduction_node->name()); + + AddToOptimizationQueue(reduction_node); + AddToOptimizationQueue(inner_function); + AddToOptimizationQueue(inner_input); } return Status::OK(); } - void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + void UpdateConsumers(NodeDef* node, const string& new_input) { const string& node_name = node->name(); const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); for (NodeDef* consumer : consumers) { diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 685b5379af..bfccc0affd 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -3224,6 +3224,30 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) { EXPECT_EQ(2, required_node_count); } +TEST_F(ArithmeticOptimizerTest, + OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"sqrt", "final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(2, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeTwice(&optimizer, &item, &output); + + // Should be a NoOp since we are not allowed to change the output of fetch + // nodes. + VerifyGraphsMatch(item.graph, output, __LINE__); +} + TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); |