aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-18 14:41:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 14:45:36 -0700
commit964a32573bffbb798d0eb97ec9b37da0657c4dbd (patch)
treea197af43ef8a2e82914c31a9e42d5fd655368973 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent33170cc661f3838aa7d0d7fc19bb0c6ba4812a3c (diff)
Clean up remove_negation pass in Grappler.
PiperOrigin-RevId: 213520177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc42
1 files changed, 15 insertions, 27 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 11ce121cba..992e85d2c6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1325,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const string node_name = node->name();
NodeDef* x;
NodeDef* y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
bool updated = false;
- if (IsAdd(*node)) {
- if (IsNeg(*x)) {
- // (-a) + b = b - a
- node->set_op("Sub");
- node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, x->input(0));
- node->add_input(AsControlDependency(x->name()));
- ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
- updated = true;
- } else if (IsNeg(*y)) {
- // a + (-b) = a - b
- node->set_op("Sub");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
- } else if (IsSub(*node)) {
- if (IsNeg(*y)) {
- // a - (-b) = a + b
- node->set_op("Add");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
+ if (IsNeg(*y)) {
+ // a - (-b) = a + b or a + (-b) = a - b
+ ForwardControlDependencies(node, {y});
+ ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+ node->set_op(IsAdd(*node) ? "Sub" : "Add");
+ node->set_input(1, y->input(0));
+ updated = true;
+ } else if (IsAdd(*node) && IsNeg(*x)) {
+ // (-a) + b = b - a
+ ForwardControlDependencies(node, {x});
+ ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+ node->set_op("Sub");
+ node->mutable_input()->SwapElements(0, 1);
+ node->set_input(1, x->input(0));
+ updated = true;
}
if (updated) {
AddToOptimizationQueue(node);