diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-18 14:41:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 14:45:36 -0700 |
commit | 964a32573bffbb798d0eb97ec9b37da0657c4dbd (patch) | |
tree | a197af43ef8a2e82914c31a9e42d5fd655368973 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 33170cc661f3838aa7d0d7fc19bb0c6ba4812a3c (diff) |
Clean up remove_negation pass in Grappler.
PiperOrigin-RevId: 213520177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 42 |
1 files changed, 24 insertions, 18 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bc838c6659..88839d944c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y); Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y); Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y); - auto add_all = ops::AddN(s.WithOpName("add_all"), - {add_x_y, add_negx_y, add_x_negy, add_negx_negy, - sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy}); + Output neg_x_with_dep = ops::Neg( + s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x); + Output add_negx_with_dep_y = + ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y); + auto add_all = + ops::AddN(s.WithOpName("add_all"), + {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y, + sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y}); GrapplerItem item; item.fetch = {"add_all"}; @@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveNegation(&optimizer); - OptimizeAndPrune(&optimizer, &item, &output); + OptimizeTwice(&optimizer, &item, &output); EXPECT_EQ(item.graph.node_size(), output.node_size()); int found = 0; @@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { if (node.name() == "Add_negx_y") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); } else if (node.name() == "Add_x_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Add_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("Neg_y", node.input(0)); - EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("Neg_x", node.input(0)); + EXPECT_EQ("y", node.input(1)); } else if (node.name() == "Sub_x_negy") { ++found; EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Sub_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(4, node.input_size()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("x", node.input(1)); + } else if (node.name() == "Add_negx_with_dep_y") { + ++found; + EXPECT_EQ("Sub", node.op()); + EXPECT_EQ(3, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); - EXPECT_EQ("^Neg_x", node.input(3)); + EXPECT_EQ("^Add_x_y", node.input(2)); } } - EXPECT_EQ(5, found); + EXPECT_EQ(6, found); auto tensors = EvaluateNodes(output, item.fetch, feed); EXPECT_EQ(1, tensors.size()); |