diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-10 11:09:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-10 11:12:23 -0700 |
commit | afa17984849881f39fb56c6e3500d866539924d5 (patch) | |
tree | fc1d4fddfcde926ad23f2fede369c30715973d47 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | c276b8314cd3161c5626d845edcfb6697cefd043 (diff) |
Adds support for hoisting out a common denominator in arithmetic_optimizer
PiperOrigin-RevId: 192314177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 85 |
1 file changed, 84 insertions, 1 deletion
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 9677175d2e..e639812858 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -31,6 +31,9 @@ namespace grappler { namespace { +constexpr char kHoistFactorOptimizerDiv[] = + "ArithmeticOptimizer/HoistCommonFactor_Div_"; + constexpr char kHoistFactorOptimizerMul[] = "ArithmeticOptimizer/HoistCommonFactor_Mul_"; @@ -42,6 +45,11 @@ string HoistMulName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, ""); } +// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation +string HoistDivName(const string& name) { + return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, ""); +} + // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation string HoistAddName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, ""); @@ -558,7 +566,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { EXPECT_EQ("^Placeholder", add_1_const_node->input(0)); } -TEST_F(ArithmeticOptimizerTest, HoistFactor) { +TEST_F(ArithmeticOptimizerTest, HoistFactorMul) { for (bool matching_shapes : {true, false}) { for (bool use_addn : {true, false}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -625,6 +633,81 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { } } +TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) { + for (bool matching_shapes : {true, false}) { + for (bool use_addn : {true, false}) { + for (bool use_ints : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = use_ints + ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2}) + : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y1 = use_ints + ? 
ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2}) + : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); + Output y2; + if (matching_shapes) { + y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2}) + : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2}); + } else { + y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1}) + : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1}); + } + Output div1 = ops::Div(s.WithOpName("div1"), y1, x); + Output div2 = ops::Div(s.WithOpName("div2"), y2, x); + Output id = + use_addn + ? ops::Identity(s.WithOpName("id"), + ops::AddN(s.WithOpName("add"), {div1, div2})) + : ops::Identity(s.WithOpName("id"), + ops::Add(s.WithOpName("add"), div1, div2)); + + GrapplerItem item; + item.fetch = {"id"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + EnableOnlyHoistCommonFactor(&optimizer); + + GraphDef output; + OptimizeTwice(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: + // + // Add Div + // / \ / \ + // Div Div -> Add x + // / \ / \ / \ + // y1 x y2 x y1 y2 + // + // If "root" op is AddN and shapes does not match, this rewrite is not + // possible and graph should stay intact. 
+ NodeMap node_map(&output); + + if ((use_addn && !matching_shapes) || use_ints) { + VerifyGraphsMatch(item.graph, output, __LINE__); + } else { + EXPECT_EQ(9, output.node_size()); + + const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add")); + ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found"; + EXPECT_EQ("y1", new_add_node->input(0)); + EXPECT_EQ("y2", new_add_node->input(1)); + + const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add")); + ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found"; + EXPECT_EQ(new_add_node->name(), new_div_node->input(0)); + EXPECT_EQ("x", new_div_node->input(1)); + + const NodeDef* id_node = node_map.GetNode("id"); + ASSERT_TRUE(id_node != nullptr) << "Id node not found"; + EXPECT_EQ("id", id_node->name()); + EXPECT_EQ(HoistDivName("add"), id_node->input(0)); + } + } + } + } +} + TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); |