diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-18 14:41:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 14:45:36 -0700 |
commit | 964a32573bffbb798d0eb97ec9b37da0657c4dbd (patch) | |
tree | a197af43ef8a2e82914c31a9e42d5fd655368973 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 33170cc661f3838aa7d0d7fc19bb0c6ba4812a3c (diff) |
Clean up remove_negation pass in Grappler.
PiperOrigin-RevId: 213520177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 42 |
1 files changed, 15 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 11ce121cba..992e85d2c6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1325,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const string node_name = node->name(); NodeDef* x; NodeDef* y; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); bool updated = false; - if (IsAdd(*node)) { - if (IsNeg(*x)) { - // (-a) + b = b - a - node->set_op("Sub"); - node->mutable_input()->SwapElements(0, 1); - node->set_input(1, x->input(0)); - node->add_input(AsControlDependency(x->name())); - ctx().node_map->AddOutput(NodeName(x->input(0)), node_name); - updated = true; - } else if (IsNeg(*y)) { - // a + (-b) = a - b - node->set_op("Sub"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } - } else if (IsSub(*node)) { - if (IsNeg(*y)) { - // a - (-b) = a + b - node->set_op("Add"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } + if (IsNeg(*y)) { + // a - (-b) = a + b or a + (-b) = a - b + ForwardControlDependencies(node, {y}); + ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0)); + node->set_op(IsAdd(*node) ? "Sub" : "Add"); + node->set_input(1, y->input(0)); + updated = true; + } else if (IsAdd(*node) && IsNeg(*x)) { + // (-a) + b = b - a + ForwardControlDependencies(node, {x}); + ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0)); + node->set_op("Sub"); + node->mutable_input()->SwapElements(0, 1); + node->set_input(1, x->input(0)); + updated = true; } if (updated) { AddToOptimizationQueue(node); |