diff options
author | 2017-10-09 13:00:39 -0700 | |
---|---|---|
committer | 2017-10-09 13:04:43 -0700 | |
commit | 27df639673ae2bfe63b82862008da9bec488f0db (patch) | |
tree | 17798e92c6dbeee4be88c67d408950430627c601 | |
parent | 1ba562a6878905c9967e999a73e749b59de56e21 (diff) |
[Grappler] Correctly replace control-dependency uses.
When redirecting the use of node A to node B, old code incorrectly replace
control dependencies with data dependencies.
PiperOrigin-RevId: 171575072
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 25 |
2 files changed, 35 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 343820de71..5c9073f049 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -555,12 +555,18 @@ void ArithmeticOptimizer::SimplifyArithmeticOps( for (NodeDef* consumer : consumers) { // Update `consumer`'s use of `node` to `input`'s operand. for (int i = 0; i < consumer->input_size(); ++i) { - if (NodeName(consumer->input(i)) == node->name()) { - *consumer->mutable_input(i) = simplified_tensor; + int operand_pos; + string operand_node_name = + ParseNodeName(consumer->input(i), &operand_pos); + if (operand_node_name == node->name()) { + *consumer->mutable_input(i) = + (operand_pos < 0 + ? AsControlDependency(NodeName(simplified_tensor)) + : simplified_tensor); } + VLOG(2) << "Update input " << consumer->input(i) << " of " + << consumer->name() << " to " << simplified_tensor; } - VLOG(2) << "Update input " << node->name() << " of " << consumer->name() - << " to " << simplified_tensor; node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor); if (!nodes_to_simplify.Exists(consumer)) { nodes_to_simplify.PushBack(consumer); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index b3405646eb..7965419ea2 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -240,6 +240,31 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) { } } +TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3})); + Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0})); + Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0})); + Output outputs = + ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2), + ops::Const(s.WithOpName("outputs_const"), 1.0f)); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + NodeMap node_map(&output); + const NodeDef* outputs_node = node_map.GetNode("outputs"); + EXPECT_EQ(2, outputs_node->input_size()); + EXPECT_EQ(outputs_node->input(0), "outputs_const"); + EXPECT_EQ(outputs_node->input(1), "^Placeholder"); +} + TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs_shape = |