aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 343820de71..5c9073f049 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -555,12 +555,18 @@ void ArithmeticOptimizer::SimplifyArithmeticOps(
for (NodeDef* consumer : consumers) {
// Update `consumer`'s use of `node` to `input`'s operand.
for (int i = 0; i < consumer->input_size(); ++i) {
- if (NodeName(consumer->input(i)) == node->name()) {
- *consumer->mutable_input(i) = simplified_tensor;
+ int operand_pos;
+ string operand_node_name =
+ ParseNodeName(consumer->input(i), &operand_pos);
+ if (operand_node_name == node->name()) {
+ *consumer->mutable_input(i) =
+ (operand_pos < 0
+ ? AsControlDependency(NodeName(simplified_tensor))
+ : simplified_tensor);
}
+ VLOG(2) << "Update input " << consumer->input(i) << " of "
+ << consumer->name() << " to " << simplified_tensor;
}
- VLOG(2) << "Update input " << node->name() << " of " << consumer->name()
- << " to " << simplified_tensor;
node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor);
if (!nodes_to_simplify.Exists(consumer)) {
nodes_to_simplify.PushBack(consumer);