aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-18 14:41:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 14:45:36 -0700
commit964a32573bffbb798d0eb97ec9b37da0657c4dbd (patch)
treea197af43ef8a2e82914c31a9e42d5fd655368973 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent33170cc661f3838aa7d0d7fc19bb0c6ba4812a3c (diff)
Clean up remove_negation pass in Grappler.
PiperOrigin-RevId: 213520177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc42
1 files changed, 24 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bc838c6659..88839d944c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
- auto add_all = ops::AddN(s.WithOpName("add_all"),
- {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
- sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+ Output neg_x_with_dep = ops::Neg(
+ s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+ Output add_negx_with_dep_y =
+ ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+ auto add_all =
+ ops::AddN(s.WithOpName("add_all"),
+ {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+ sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
GrapplerItem item;
item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &output);
+ OptimizeTwice(&optimizer, &item, &output);
EXPECT_EQ(item.graph.node_size(), output.node_size());
int found = 0;
@@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
if (node.name() == "Add_negx_y") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
} else if (node.name() == "Add_x_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Add_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Neg_y", node.input(0));
- EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("Neg_x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Sub_x_negy") {
++found;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Sub_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ } else if (node.name() == "Add_negx_with_dep_y") {
+ ++found;
+ EXPECT_EQ("Sub", node.op());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
- EXPECT_EQ("^Neg_x", node.input(3));
+ EXPECT_EQ("^Add_x_y", node.input(2));
}
}
- EXPECT_EQ(5, found);
+ EXPECT_EQ(6, found);
auto tensors = EvaluateNodes(output, item.fetch, feed);
EXPECT_EQ(1, tensors.size());