diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-06-05 13:12:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 13:14:33 -0700 |
commit | c681be04ec15cdfc225bc61132420781bf23d298 (patch) | |
tree | e704da33b9a6cb2dbda0fe1fbe49d2e4e0a7a171 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | b7928ac78d3cd688967bcf4e5253e384b355070f (diff) |
Move SimplifyAggregation to separate aggregation stage.
PiperOrigin-RevId: 199346067
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 68 |
1 files changed, 48 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f15cbfe407..f79347cde6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -40,21 +40,37 @@ constexpr char kHoistFactorOptimizerMul[] = constexpr char kHoistFactorOptimizerAdd[] = "ArithmeticOptimizer/HoistCommonFactor_Add_"; -// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation +constexpr char kSimplifyAggregationConst[] = + "ArithmeticOptimizer/SimplifyAggregation_Const_"; + +constexpr char kSimplifyAggregationMul[] = + "ArithmeticOptimizer/SimplifyAggregation_Mul_"; + +// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation. string HoistMulName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, ""); } -// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation +// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation. string HoistDivName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, ""); } -// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation +// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation. string HoistAddName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, ""); } +// Optimized name of Const node by SimplifyAggregation. +string AggregationConstName(const string& name) { + return AddPrefixToNodeName(name, kSimplifyAggregationConst, ""); +} + +// Optimized name of Mul node by SimplifyAggregation. +string AggregationMulName(const string& name) { + return AddPrefixToNodeName(name, kSimplifyAggregationMul, ""); +} + string OptimizedName(const string& name) { return AddPrefixToNodeName(name, kArithmeticOptimizer); } @@ -140,6 +156,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_logical_not = false; options.reorder_cast_and_transpose = false; options.replace_mul_with_square = false; + options.simplify_aggregation = false; optimizer->options_ = options; } @@ -226,6 +243,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_logical_not = true; } + + void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_aggregation = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -500,10 +522,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { Output id = ops::Identity(s.WithOpName("id"), add); GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector<string> fetch = {"id"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; @@ -513,22 +535,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { EXPECT_EQ(5, output.node_size()); - const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + const string optimized_const_name = AggregationConstName("add"); + const string optimized_mul_name = AggregationMulName("add"); + + const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); - const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); - EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ(optimized_const_name, new_mul->input(0)); EXPECT_EQ("x", new_mul->input(1)); const NodeDef* new_id = node_map.GetNode("id"); ASSERT_NE(new_id, nullptr); - EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + EXPECT_EQ(optimized_mul_name, new_id->input(0)); - auto tensors = EvaluateNodes(output, fetch); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } @@ -554,21 +579,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { EXPECT_EQ(6, output.node_size()); - const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + const string optimized_const_name = AggregationConstName("add"); + const string optimized_mul_name = AggregationMulName("add"); + + const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); - const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); - EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ(optimized_const_name, new_mul->input(0)); EXPECT_EQ("x", new_mul->input(1)); EXPECT_EQ("^y", new_mul->input(2)); const NodeDef* new_id = node_map.GetNode("id"); ASSERT_NE(new_id, nullptr); - EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + EXPECT_EQ(optimized_mul_name, new_id->input(0)); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); @@ -633,24 +661,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { ASSERT_NE(add_4_node, nullptr); EXPECT_EQ("Add", add_4_node->op()); EXPECT_EQ(2, add_4_node->input_size()); - EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1)); + EXPECT_EQ(AggregationConstName("Add"), add_4_node->input(0)); + EXPECT_EQ(AggregationConstName("Add_1"), add_4_node->input(1)); const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5")); ASSERT_NE(add_5_node, nullptr); EXPECT_EQ("Add", add_5_node->op()); EXPECT_EQ(2, add_5_node->input_size()); - EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1)); + EXPECT_EQ(AggregationConstName("Add"), add_5_node->input(0)); + EXPECT_EQ(AggregationConstName("Add_1"), add_5_node->input(1)); - const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const")); + const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add")); ASSERT_NE(add_const_node, nullptr); EXPECT_EQ("Const", add_const_node->op()); EXPECT_EQ(1, add_const_node->input_size()); EXPECT_EQ("^Placeholder", add_const_node->input(0)); const NodeDef* add_1_const_node = - node_map.GetNode(OptimizedName("Add_1_const")); + node_map.GetNode(AggregationConstName("Add_1")); ASSERT_NE(add_1_const_node, nullptr); EXPECT_EQ("Const", add_1_const_node->op()); EXPECT_EQ(1, add_1_const_node->input_size()); |