aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 18:25:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 18:29:29 -0700
commit58745044c7224f894a51dcf25e26412fa23fc444 (patch)
tree04ad4c60a6ccc670212f3d5e1870d3f0cb82196f /tensorflow/core/grappler
parentab96371eaea4cc5b2f9c431eec455a1cf4be7c1c (diff)
Fix bug in hoisting monotonic functions out of reductions: Do not change the value of nodes in the preserve set, e.g. fetch nodes.
I simplified the rewiring logic a tad. PiperOrigin-RevId: 211017989
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc24
2 files changed, 39 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fb2fe6883..4fed88d536 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2703,22 +2703,31 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 685b5379af..bfccc0affd 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3224,6 +3224,30 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"sqrt", "final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // Should be a NoOp since we are not allowed to change the output of fetch
+ // nodes.
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();