diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 343820de71..5c9073f049 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -555,12 +555,18 @@ void ArithmeticOptimizer::SimplifyArithmeticOps( for (NodeDef* consumer : consumers) { // Update `consumer`'s use of `node` to `input`'s operand. for (int i = 0; i < consumer->input_size(); ++i) { - if (NodeName(consumer->input(i)) == node->name()) { - *consumer->mutable_input(i) = simplified_tensor; + int operand_pos; + string operand_node_name = + ParseNodeName(consumer->input(i), &operand_pos); + if (operand_node_name == node->name()) { + *consumer->mutable_input(i) = + (operand_pos < 0 + ? AsControlDependency(NodeName(simplified_tensor)) + : simplified_tensor); } + VLOG(2) << "Update input " << consumer->input(i) << " of " + << consumer->name() << " to " << simplified_tensor; } - VLOG(2) << "Update input " << node->name() << " of " << consumer->name() - << " to " << simplified_tensor; node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor); if (!nodes_to_simplify.Exists(consumer)) { nodes_to_simplify.PushBack(consumer); |