diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-06-05 12:19:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 12:22:39 -0700 |
commit | 2b5f598fbd822f911ad305ae1e57325aefd50826 (patch) | |
tree | 30ced01eceaa62a99ea7908688df5f79bf4c46d6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff) |
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index b9fec0f860..f15cbfe407 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -139,6 +139,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_negation = false; options.remove_logical_not = false; options.reorder_cast_and_transpose = false; + options.replace_mul_with_square = false; optimizer->options_ = options; } @@ -201,6 +202,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.reorder_cast_and_transpose = true; } + void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_mul_with_square = true; + } + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_cwise_unary_chains = true; @@ -345,33 +351,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, MulToSquare) { +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2}); Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c); Output id = ops::Identity(s.WithOpName("id"), mul); + GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector<string> fetch = {"id"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithSquare(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); - EXPECT_EQ(5, output.node_size()); - EXPECT_EQ("id", output.node(3).name()); - EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0)); - EXPECT_EQ("Square", output.node(4).op()); - EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name()); - EXPECT_EQ(2, output.node(4).input_size()); - EXPECT_EQ("c", output.node(4).input(0)); - EXPECT_EQ("^d", output.node(4).input(1)); + EXPECT_EQ(4, output.node_size()); - auto tensors = EvaluateNodes(output, fetch); + NodeMap node_map(&output); + const string p = "ArithmeticOptimizer/ReplaceMulWithSquare"; + const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul")); + + ASSERT_NE(square_node, nullptr); + EXPECT_EQ("Square", square_node->op()); + EXPECT_EQ("c", square_node->input(0)); + EXPECT_EQ("^d", square_node->input(1)); + + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } @@ -386,12 +395,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); auto id = ops::Identity(s.WithOpName("id"), recip2); - std::vector<string> fetch = {"id"}; - GrapplerItem item; - item.fetch = fetch; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); GraphDef output; @@ -404,7 +411,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { EXPECT_EQ("id", output.node(1).name()); EXPECT_EQ("c", output.node(1).input(0)); - auto tensors = EvaluateNodes(output, fetch); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } |