aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2017-10-09 13:00:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-09 13:04:43 -0700
commit27df639673ae2bfe63b82862008da9bec488f0db (patch)
tree17798e92c6dbeee4be88c67d408950430627c601
parent1ba562a6878905c9967e999a73e749b59de56e21 (diff)
[Grappler] Correctly replace control-dependency uses.
When redirecting the use of node A to node B, old code incorrectly replace control dependencies with data dependencies. PiperOrigin-RevId: 171575072
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc25
2 files changed, 35 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 343820de71..5c9073f049 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -555,12 +555,18 @@ void ArithmeticOptimizer::SimplifyArithmeticOps(
for (NodeDef* consumer : consumers) {
// Update `consumer`'s use of `node` to `input`'s operand.
for (int i = 0; i < consumer->input_size(); ++i) {
- if (NodeName(consumer->input(i)) == node->name()) {
- *consumer->mutable_input(i) = simplified_tensor;
+ int operand_pos;
+ string operand_node_name =
+ ParseNodeName(consumer->input(i), &operand_pos);
+ if (operand_node_name == node->name()) {
+ *consumer->mutable_input(i) =
+ (operand_pos < 0
+ ? AsControlDependency(NodeName(simplified_tensor))
+ : simplified_tensor);
}
+ VLOG(2) << "Update input " << consumer->input(i) << " of "
+ << consumer->name() << " to " << simplified_tensor;
}
- VLOG(2) << "Update input " << node->name() << " of " << consumer->name()
- << " to " << simplified_tensor;
node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor);
if (!nodes_to_simplify.Exists(consumer)) {
nodes_to_simplify.PushBack(consumer);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b3405646eb..7965419ea2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -240,6 +240,31 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
}
}
+TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
+ Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
+ Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
+ Output outputs =
+ ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
+ ops::Const(s.WithOpName("outputs_const"), 1.0f));
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ NodeMap node_map(&output);
+ const NodeDef* outputs_node = node_map.GetNode("outputs");
+ EXPECT_EQ(2, outputs_node->input_size());
+ EXPECT_EQ(outputs_node->input(0), "outputs_const");
+ EXPECT_EQ(outputs_node->input(1), "^Placeholder");
+}
+
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs_shape =