aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-10 11:09:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 11:12:23 -0700
commitafa17984849881f39fb56c6e3500d866539924d5 (patch)
treefc1d4fddfcde926ad23f2fede369c30715973d47 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentc276b8314cd3161c5626d845edcfb6697cefd043 (diff)
Adds support for hoisting out common denominator in arithmetic_optimizer
PiperOrigin-RevId: 192314177
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc85
1 files changed, 84 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 9677175d2e..e639812858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -31,6 +31,9 @@ namespace grappler {
namespace {
+constexpr char kHoistFactorOptimizerDiv[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Div_";
+
constexpr char kHoistFactorOptimizerMul[] =
"ArithmeticOptimizer/HoistCommonFactor_Mul_";
@@ -42,6 +45,11 @@ string HoistMulName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
}
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+string HoistDivName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
+}
+
// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
string HoistAddName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
@@ -558,7 +566,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
}
-TEST_F(ArithmeticOptimizerTest, HoistFactor) {
+TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
for (bool matching_shapes : {true, false}) {
for (bool use_addn : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -625,6 +633,81 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
}
}
+TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
+ for (bool matching_shapes : {true, false}) {
+ for (bool use_addn : {true, false}) {
+ for (bool use_ints : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = use_ints
+ ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
+ : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output y1 = use_ints
+ ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
+ : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+ Output y2;
+ if (matching_shapes) {
+ y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
+ : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
+ } else {
+ y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
+ : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+ }
+ Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
+ Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
+ Output id =
+ use_addn
+ ? ops::Identity(s.WithOpName("id"),
+ ops::AddN(s.WithOpName("add"), {div1, div2}))
+ : ops::Identity(s.WithOpName("id"),
+ ops::Add(s.WithOpName("add"), div1, div2));
+
+ GrapplerItem item;
+ item.fetch = {"id"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCommonFactor(&optimizer);
+
+ GraphDef output;
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // Add Div
+ // / \ / \
+ // Div Div -> Add x
+ // / \ / \ / \
+ // y1 x y2 x y1 y2
+ //
+ // If "root" op is AddN and shapes does not match, this rewrite is not
+ // possible and graph should stay intact.
+ NodeMap node_map(&output);
+
+ if ((use_addn && !matching_shapes) || use_ints) {
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+ } else {
+ EXPECT_EQ(9, output.node_size());
+
+ const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+ ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ EXPECT_EQ("y1", new_add_node->input(0));
+ EXPECT_EQ("y2", new_add_node->input(1));
+
+ const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
+ ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
+ EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
+ EXPECT_EQ("x", new_div_node->input(1));
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ EXPECT_EQ("id", id_node->name());
+ EXPECT_EQ(HoistDivName("add"), id_node->input(0));
+ }
+ }
+ }
+ }
+}
+
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});