From 27df639673ae2bfe63b82862008da9bec488f0db Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 9 Oct 2017 13:00:39 -0700 Subject: [Grappler] Correctly replace control-dependency uses. When redirecting the use of node A to node B, old code incorrectly replace control dependencies with data dependencies. PiperOrigin-RevId: 171575072 --- .../grappler/optimizers/arithmetic_optimizer.cc | 14 ++++++++---- .../optimizers/arithmetic_optimizer_test.cc | 25 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 343820de71..5c9073f049 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -555,12 +555,18 @@ void ArithmeticOptimizer::SimplifyArithmeticOps( for (NodeDef* consumer : consumers) { // Update `consumer`'s use of `node` to `input`'s operand. for (int i = 0; i < consumer->input_size(); ++i) { - if (NodeName(consumer->input(i)) == node->name()) { - *consumer->mutable_input(i) = simplified_tensor; + int operand_pos; + string operand_node_name = + ParseNodeName(consumer->input(i), &operand_pos); + if (operand_node_name == node->name()) { + *consumer->mutable_input(i) = + (operand_pos < 0 + ? AsControlDependency(NodeName(simplified_tensor)) + : simplified_tensor); } + VLOG(2) << "Update input " << consumer->input(i) << " of " + << consumer->name() << " to " << simplified_tensor; } - VLOG(2) << "Update input " << node->name() << " of " << consumer->name() - << " to " << simplified_tensor; node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor); if (!nodes_to_simplify.Exists(consumer)) { nodes_to_simplify.PushBack(consumer); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index b3405646eb..7965419ea2 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -240,6 +240,31 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) { } } +TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3})); + Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0})); + Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0})); + Output outputs = + ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2), + ops::Const(s.WithOpName("outputs_const"), 1.0f)); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + NodeMap node_map(&output); + const NodeDef* outputs_node = node_map.GetNode("outputs"); + EXPECT_EQ(2, outputs_node->input_size()); + EXPECT_EQ(outputs_node->input(0), "outputs_const"); + EXPECT_EQ(outputs_node->input(1), "^Placeholder"); +} + TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs_shape = -- cgit v1.2.3