aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-06-05 12:19:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 12:22:39 -0700
commit2b5f598fbd822f911ad305ae1e57325aefd50826 (patch)
tree30ced01eceaa62a99ea7908688df5f79bf4c46d6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff)
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc47
1 files changed, 27 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b9fec0f860..f15cbfe407 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -139,6 +139,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_negation = false;
options.remove_logical_not = false;
options.reorder_cast_and_transpose = false;
+ options.replace_mul_with_square = false;
optimizer->options_ = options;
}
@@ -201,6 +202,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.reorder_cast_and_transpose = true;
}
+ void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.replace_mul_with_square = true;
+ }
+
void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_cwise_unary_chains = true;
@@ -345,33 +351,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, MulToSquare) {
+TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
Output id = ops::Identity(s.WithOpName("id"), mul);
+
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyReplaceMulWithSquare(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("id", output.node(3).name());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0));
- EXPECT_EQ("Square", output.node(4).op());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name());
- EXPECT_EQ(2, output.node(4).input_size());
- EXPECT_EQ("c", output.node(4).input(0));
- EXPECT_EQ("^d", output.node(4).input(1));
+ EXPECT_EQ(4, output.node_size());
- auto tensors = EvaluateNodes(output, fetch);
+ NodeMap node_map(&output);
+ const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
+ const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul"));
+
+ ASSERT_NE(square_node, nullptr);
+ EXPECT_EQ("Square", square_node->op());
+ EXPECT_EQ("c", square_node->input(0));
+ EXPECT_EQ("^d", square_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -386,12 +395,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
auto id = ops::Identity(s.WithOpName("id"), recip2);
- std::vector<string> fetch = {"id"};
-
GrapplerItem item;
- item.fetch = fetch;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
@@ -404,7 +411,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
EXPECT_EQ("id", output.node(1).name());
EXPECT_EQ("c", output.node(1).input(0));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}