about summary refs log tree commit diff homepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 12:41:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 12:44:20 -0700
commitc8064f1ac3c42951aa1593260346b75d306ffe95 (patch)
tree60a9e2f15d6982d6c47878862dae5bed324476dc /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentfc34c057d9d1118477b3e02870b97305c2d1af86 (diff)
Rewrite Add/AddN subgraph, minimizing number of required broadcasts.
1) Collect into AddOpsGroup inputs of symbolically defined shapes that can be broadcast to the root shape. 2) Rewrite equal shapes with AddN(s). 3) Build an Add tree from aggregations of different shapes, minimizing the cost of broadcast. PiperOrigin-RevId: 191331566
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc441
1 file changed, 296 insertions, 145 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index ad3edc144a..ef3ed35fa6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -156,25 +156,23 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"div"};
- ArithmeticOptimizer optimizer;
- GraphDef output;
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
EXPECT_EQ(1, tensors_expected.size());
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(2, output.node_size());
- const NodeDef& new_c1 = output.node(0);
- EXPECT_EQ("c1", new_c1.name());
- const NodeDef& new_div = output.node(1);
- EXPECT_EQ("div", new_div.name());
- EXPECT_EQ(2, new_div.input_size());
- EXPECT_EQ("c1", new_div.input(0));
- EXPECT_EQ("c1", new_div.input(1));
+
+ const NodeDef* new_c1 = node_map.GetNode("c1");
+ ASSERT_NE(new_c1, nullptr);
+
+ const NodeDef* new_div = node_map.GetNode("div");
+ ASSERT_NE(new_div, nullptr);
+ EXPECT_EQ(2, new_div->input_size());
+ EXPECT_EQ("c1", new_div->input(0));
+ EXPECT_EQ("c1", new_div->input(1));
auto tensors = EvaluateNodes(output, item.fetch, {});
EXPECT_EQ(1, tensors.size());
@@ -198,20 +196,18 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(5, output.node_size());
- const NodeDef& new_div = output.node(3);
- EXPECT_EQ(4, new_div.input_size());
- EXPECT_EQ("check1", new_div.input(0));
- EXPECT_EQ("check1", new_div.input(1));
- EXPECT_EQ("^assert1", new_div.input(2));
- EXPECT_EQ("^assert1", new_div.input(3));
+ const NodeDef* new_div = node_map.GetNode("div");
+ ASSERT_NE(new_div, nullptr);
+ EXPECT_EQ(4, new_div->input_size());
+ EXPECT_EQ("check1", new_div->input(0));
+ EXPECT_EQ("check1", new_div->input(1));
+ EXPECT_EQ("^assert1", new_div->input(2));
+ EXPECT_EQ("^assert1", new_div->input(3));
}
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
@@ -227,28 +223,24 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(4, output.node_size());
- const NodeDef& new_c1 = output.node(0);
- EXPECT_EQ("c1", new_c1.name());
- const NodeDef& new_c2 = output.node(1);
- EXPECT_EQ("c2", new_c2.name());
- const NodeDef& new_mul1 = output.node(2);
- EXPECT_EQ("mul1", new_mul1.name());
- EXPECT_EQ(2, new_mul1.input_size());
- EXPECT_EQ("c1", new_mul1.input(0));
- EXPECT_EQ("c2", new_mul1.input(1));
- const NodeDef& new_div1 = output.node(3);
- EXPECT_EQ("div1", new_div1.name());
- EXPECT_EQ(2, new_div1.input_size());
- EXPECT_EQ("mul1", new_div1.input(0));
- EXPECT_EQ("mul1", new_div1.input(1));
+ const NodeDef* new_c1 = node_map.GetNode("c1");
+ ASSERT_NE(new_c1, nullptr);
+ const NodeDef* new_c2 = node_map.GetNode("c2");
+ ASSERT_NE(new_c2, nullptr);
+ const NodeDef* new_mul1 = node_map.GetNode("mul1");
+ ASSERT_NE(new_mul1, nullptr);
+ EXPECT_EQ(2, new_mul1->input_size());
+ EXPECT_EQ("c1", new_mul1->input(0));
+ EXPECT_EQ("c2", new_mul1->input(1));
+ const NodeDef* new_div1 = node_map.GetNode("div1");
+ ASSERT_NE(new_div1, nullptr);
+ EXPECT_EQ(2, new_div1->input_size());
+ EXPECT_EQ("mul1", new_div1->input(0));
+ EXPECT_EQ("mul1", new_div1->input(1));
}
TEST_F(ArithmeticOptimizerTest, MulToSquare) {
@@ -364,26 +356,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(5, output.node_size());
- const NodeDef& new_const = output.node(3);
- EXPECT_EQ(OptimizedName("add_const"), new_const.name());
- EXPECT_EQ("^x", new_const.input(0));
+
+ const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ ASSERT_NE(new_const, nullptr);
+ EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
- new_const.attr().at("value").tensor().tensor_content());
- const NodeDef& new_mul = output.node(4);
- EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
- EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
- EXPECT_EQ("x", new_mul.input(1));
- const NodeDef& new_id = output.node(2);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
+ new_const->attr().at("value").tensor().tensor_content());
+
+ const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ ASSERT_NE(new_mul, nullptr);
+ EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ("x", new_mul->input(1));
+
+ const NodeDef* new_id = node_map.GetNode("id");
+ ASSERT_NE(new_id, nullptr);
+ EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
}
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
@@ -398,27 +389,26 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(6, output.node_size());
- const NodeDef& new_const = output.node(4);
- EXPECT_EQ(OptimizedName("add_const"), new_const.name());
- EXPECT_EQ("^x", new_const.input(0));
+
+ const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ ASSERT_NE(new_const, nullptr);
+ EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
- new_const.attr().at("value").tensor().tensor_content());
- const NodeDef& new_mul = output.node(5);
- EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
- EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
- EXPECT_EQ("x", new_mul.input(1));
- EXPECT_EQ("^y", new_mul.input(2));
- const NodeDef& new_id = output.node(3);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
+ new_const->attr().at("value").tensor().tensor_content());
+
+ const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ ASSERT_NE(new_mul, nullptr);
+ EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ("x", new_mul->input(1));
+ EXPECT_EQ("^y", new_mul->input(2));
+
+ const NodeDef* new_id = node_map.GetNode("id");
+ ASSERT_NE(new_id, nullptr);
+ EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
}
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
@@ -458,25 +448,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ(17, output.node_size());
const NodeDef* id_node = node_map.GetNode("id");
- ASSERT_TRUE(id_node != nullptr);
+ ASSERT_NE(id_node, nullptr);
EXPECT_EQ(1, id_node->input_size());
EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
- ASSERT_TRUE(mul_node != nullptr);
+ ASSERT_NE(mul_node, nullptr);
EXPECT_EQ(2, mul_node->input_size());
EXPECT_EQ("Placeholder", mul_node->input(0));
EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
- ASSERT_TRUE(add_6_node != nullptr);
+ ASSERT_NE(add_6_node, nullptr);
EXPECT_EQ(3, add_6_node->input_size());
EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
EXPECT_EQ("^Placeholder", add_6_node->input(2));
const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
- ASSERT_TRUE(add_4_node != nullptr);
+ ASSERT_NE(add_4_node, nullptr);
EXPECT_EQ("Add", add_4_node->op());
EXPECT_EQ(3, add_4_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
@@ -484,7 +474,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ("^Placeholder", add_4_node->input(2));
const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
- ASSERT_TRUE(add_5_node != nullptr);
+ ASSERT_NE(add_5_node, nullptr);
EXPECT_EQ("Add", add_5_node->op());
EXPECT_EQ(3, add_5_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
@@ -492,14 +482,14 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ("^Placeholder", add_5_node->input(2));
const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
- ASSERT_TRUE(add_const_node != nullptr);
+ ASSERT_NE(add_const_node, nullptr);
EXPECT_EQ("Const", add_const_node->op());
EXPECT_EQ(1, add_const_node->input_size());
EXPECT_EQ("^Placeholder", add_const_node->input(0));
const NodeDef* add_1_const_node =
node_map.GetNode(OptimizedName("Add_1_const"));
- ASSERT_TRUE(add_1_const_node != nullptr);
+ ASSERT_NE(add_1_const_node, nullptr);
EXPECT_EQ("Const", add_1_const_node->op());
EXPECT_EQ(1, add_1_const_node->input_size());
EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
@@ -550,17 +540,17 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
EXPECT_EQ(9, output.node_size());
const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
- ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
EXPECT_EQ("y1", new_add_node->input(0));
EXPECT_EQ("y2", new_add_node->input(1));
const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
- ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found";
+ ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
EXPECT_EQ("x", new_mul_node->input(0));
EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
const NodeDef* id_node = node_map.GetNode("id");
- ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ ASSERT_NE(id_node, nullptr) << "Id node not found";
EXPECT_EQ("id", id_node->name());
EXPECT_EQ(HoistMulName("add"), id_node->input(0));
}
@@ -581,18 +571,17 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("trans_fused"), output.node(6).name());
- EXPECT_EQ("ConjugateTranspose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* trans_fused_node =
+ node_map.GetNode(OptimizedName("trans_fused"));
+ ASSERT_NE(trans_fused_node, nullptr);
+ EXPECT_EQ("ConjugateTranspose", trans_fused_node->op());
+ EXPECT_EQ("z", trans_fused_node->input(0));
+ EXPECT_EQ("perm", trans_fused_node->input(1));
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
@@ -609,14 +598,16 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("conjugate_trans_fused"), output.node(6).name());
- EXPECT_EQ("Transpose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* conjugate_trans_fused_node =
+ node_map.GetNode(OptimizedName("conjugate_trans_fused"));
+ EXPECT_EQ("Transpose", conjugate_trans_fused_node->op());
+ EXPECT_EQ("z", conjugate_trans_fused_node->input(0));
+ EXPECT_EQ("perm", conjugate_trans_fused_node->input(1));
}
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
@@ -632,18 +623,16 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("conj_fused"), output.node(6).name());
- EXPECT_EQ("ConjugateTranspose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* conj_fused_node =
+ node_map.GetNode(OptimizedName("conj_fused"));
+ EXPECT_EQ("ConjugateTranspose", conj_fused_node->op());
+ EXPECT_EQ("z", conj_fused_node->input(0));
+ EXPECT_EQ("perm", conj_fused_node->input(1));
}
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
@@ -668,23 +657,22 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("matmul_fused"), output.node(6).name());
- EXPECT_EQ("a", output.node(6).input(0));
- EXPECT_EQ("b", output.node(6).input(1));
+
+ const NodeDef* matmul_fused_node =
+ node_map.GetNode(OptimizedName("matmul_fused"));
+ ASSERT_NE(matmul_fused_node, nullptr);
+ EXPECT_EQ("a", matmul_fused_node->input(0));
+ EXPECT_EQ("b", matmul_fused_node->input(1));
if (matmul_type == "BatchMatMul") {
- EXPECT_TRUE(output.node(6).attr().at("adj_x").b());
- EXPECT_TRUE(output.node(6).attr().at("adj_y").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
} else {
- EXPECT_TRUE(output.node(6).attr().at("transpose_a").b());
- EXPECT_TRUE(output.node(6).attr().at("transpose_b").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
}
}
}
@@ -1322,8 +1310,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
// check add tree was replaced with AddN
const NodeDef* collapsed_add =
- node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_add != nullptr);
+ node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
@@ -1333,7 +1321,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
// check output was re-wired to new node
const NodeDef* updated_outputs = node_map.GetNode("outputs");
- ASSERT_TRUE(updated_outputs != nullptr);
+ ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
@@ -1381,8 +1369,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check left Add subtree replaced with AddN
const NodeDef* collapsed_left =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_left != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_left, nullptr);
EXPECT_EQ("AddN", collapsed_left->op());
EXPECT_EQ(3, collapsed_left->input_size());
@@ -1392,8 +1380,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check right Add subtree replaced with AddN
const NodeDef* collapsed_right =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy");
- ASSERT_TRUE(collapsed_right != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
+ ASSERT_NE(collapsed_right, nullptr);
EXPECT_EQ("AddN", collapsed_right->op());
EXPECT_EQ(3, collapsed_right->input_size());
@@ -1403,7 +1391,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check that Mul inputs re-wired to new Nodes
const NodeDef* updated_mul = node_map.GetNode("Mul");
- ASSERT_TRUE(updated_mul != nullptr);
+ ASSERT_NE(updated_mul, nullptr);
EXPECT_EQ("Mul", updated_mul->op());
EXPECT_EQ(2, updated_mul->input_size());
@@ -1444,9 +1432,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
NodeMap node_map(&output);
// check Add tree replaced with AddN
- const NodeDef* collapsed_add = node_map.GetNode(
- "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc");
- ASSERT_TRUE(collapsed_add != nullptr);
+ const NodeDef* collapsed_add =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(4, collapsed_add->input_size());
@@ -1496,8 +1484,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
// check add tree was replaced with AddN
const NodeDef* collapsed_add =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_add != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
EXPECT_EQ("a", collapsed_add->input(0));
@@ -1506,10 +1494,173 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
// check output was re-wired to new node
const NodeDef* updated_outputs = node_map.GetNode("outputs");
- ASSERT_TRUE(updated_outputs != nullptr);
+ ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
+ auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
+ auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
+ auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
+ auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
+
+ auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ // 1) [a, x], [b, y], [c, z] - aggregate same shapes first
+ // 2) Build an aggregation tree minimizing cost of broadcast
+ //
+ // + +
+ // / \ / \
+ // + + + AddN(c, z)
+ // / \ / \ / \
+ // + c x + --> AddN(a, x) AddN(b, y)
+ // / \ / \
+ // a b y z
+ EXPECT_EQ(12, output.node_size());
+ NodeMap node_map(&output);
+
+ // expected names of outer and inner nodes
+ string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
+ string outer_0_add_name =
+ "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
+ string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
+ string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
+ string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
+
+ // Add [a, x] first
+ const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
+ ASSERT_NE(add_ax_node, nullptr);
+ EXPECT_EQ("AddN", add_ax_node->op());
+ EXPECT_EQ(2, add_ax_node->input_size());
+ EXPECT_EQ("a", add_ax_node->input(0));
+ EXPECT_EQ("x", add_ax_node->input(1));
+
+ // Then add [b, y]
+ const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
+ ASSERT_NE(add_by_node, nullptr);
+ EXPECT_EQ("AddN", add_by_node->op());
+ EXPECT_EQ(2, add_by_node->input_size());
+ EXPECT_EQ("b", add_by_node->input(0));
+ EXPECT_EQ("y", add_by_node->input(1));
+
+ // Then add [c, z]
+ const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
+ ASSERT_NE(add_cz_node, nullptr);
+ EXPECT_EQ("AddN", add_cz_node->op());
+ EXPECT_EQ(2, add_cz_node->input_size());
+ EXPECT_EQ("c", add_cz_node->input(0));
+ EXPECT_EQ("z", add_cz_node->input(1));
+
+ // Then add results together starting from smaller shapes [a, x] + [b, y]
+ const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
+ ASSERT_NE(outer_0_node, nullptr);
+ EXPECT_EQ("Add", outer_0_node->op());
+ EXPECT_EQ(2, outer_0_node->input_size());
+ EXPECT_EQ(inner_0_add_name, outer_0_node->input(0));
+ EXPECT_EQ(inner_1_add_name, outer_0_node->input(1));
+
+ // And finally top level Add node
+ const NodeDef* outer_node = node_map.GetNode(outer_add_name);
+ ASSERT_NE(outer_node, nullptr);
+ EXPECT_EQ("Add", outer_node->op());
+ EXPECT_EQ(2, outer_node->input_size());
+ EXPECT_EQ(outer_0_add_name, outer_node->input(0));
+ EXPECT_EQ(inner_2_add_name, outer_node->input(1));
+
+ // And outputs reading new top level Add node
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_NE(updated_outputs, nullptr);
+ EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // We have a small input with one unknown dimension
+ auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_FLOAT);
+
+ // And second input which is larger, but has the same unknown dimension
+ // device spec prevents this node from rewriting
+ auto d = "/job:do_not_rewrite_me";
+ auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_FLOAT);
+ auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
+
+ // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
+ auto a = ops::Sqrt(s.WithOpName("a"), small);
+ auto b = ops::Square(s.WithOpName("b"), large);
+ auto c = ops::Round(s.WithOpName("c"), small);
+
+ // [add_ab, add_abc] shape must be inferred from inputs
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyAddToAddNCombining(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur: it's much cheaper to add small
+ // tensors, and do the broadcast just once
+ //
+ // + +
+ // / \ / \
+ // + c --> + b
+ // / \ / \
+ // a b a c
+ EXPECT_EQ(9, output.node_size());
+ NodeMap node_map(&output);
+
+ // expected names of outer and inner nodes
+ string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
+ string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";
+
+ // outer Add node
+ const NodeDef* outer_add = node_map.GetNode(outer_add_name);
+ ASSERT_NE(outer_add, nullptr);
+ EXPECT_EQ("Add", outer_add->op());
+ EXPECT_EQ(inner_add_name, outer_add->input(0));
+ EXPECT_EQ("b", outer_add->input(1));
+
+ // inner AddN node
+ const NodeDef* inner_add = node_map.GetNode(inner_add_name);
+ ASSERT_NE(inner_add, nullptr);
+ EXPECT_EQ(2, inner_add->input_size());
+ EXPECT_EQ("a", inner_add->input(0));
+ EXPECT_EQ("c", inner_add->input(1));
+
+ // check output was re-wired to new node
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_NE(updated_outputs, nullptr);
+ EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+}
+
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);