aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-06-05 13:12:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 13:14:33 -0700
commitc681be04ec15cdfc225bc61132420781bf23d298 (patch)
treee704da33b9a6cb2dbda0fe1fbe49d2e4e0a7a171 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentb7928ac78d3cd688967bcf4e5253e384b355070f (diff)
Move SimplifyAggregation to separate aggregation stage.
PiperOrigin-RevId: 199346067
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc68
1 files changed, 48 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index f15cbfe407..f79347cde6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -40,21 +40,37 @@ constexpr char kHoistFactorOptimizerMul[] =
constexpr char kHoistFactorOptimizerAdd[] =
"ArithmeticOptimizer/HoistCommonFactor_Add_";
-// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation
+constexpr char kSimplifyAggregationConst[] =
+ "ArithmeticOptimizer/SimplifyAggregation_Const_";
+
+constexpr char kSimplifyAggregationMul[] =
+ "ArithmeticOptimizer/SimplifyAggregation_Mul_";
+
+// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
string HoistMulName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
}
-// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
string HoistDivName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
}
-// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
+// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
string HoistAddName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
}
+// Optimized name of Const node by SimplifyAggregation.
+string AggregationConstName(const string& name) {
+ return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
+}
+
+// Optimized name of Mul node by SimplifyAggregation.
+string AggregationMulName(const string& name) {
+ return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
+}
+
string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
@@ -140,6 +156,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_logical_not = false;
options.reorder_cast_and_transpose = false;
options.replace_mul_with_square = false;
+ options.simplify_aggregation = false;
optimizer->options_ = options;
}
@@ -226,6 +243,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.remove_logical_not = true;
}
+
+ void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.simplify_aggregation = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -500,10 +522,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
Output id = ops::Identity(s.WithOpName("id"), add);
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
@@ -513,22 +535,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
EXPECT_EQ(5, output.node_size());
- const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ const string optimized_const_name = AggregationConstName("add");
+ const string optimized_mul_name = AggregationMulName("add");
+
+ const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
- const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
ASSERT_NE(new_mul, nullptr);
- EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ(optimized_const_name, new_mul->input(0));
EXPECT_EQ("x", new_mul->input(1));
const NodeDef* new_id = node_map.GetNode("id");
ASSERT_NE(new_id, nullptr);
- EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+ EXPECT_EQ(optimized_mul_name, new_id->input(0));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -554,21 +579,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
EXPECT_EQ(6, output.node_size());
- const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ const string optimized_const_name = AggregationConstName("add");
+ const string optimized_mul_name = AggregationMulName("add");
+
+ const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
- const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
ASSERT_NE(new_mul, nullptr);
- EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ(optimized_const_name, new_mul->input(0));
EXPECT_EQ("x", new_mul->input(1));
EXPECT_EQ("^y", new_mul->input(2));
const NodeDef* new_id = node_map.GetNode("id");
ASSERT_NE(new_id, nullptr);
- EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+ EXPECT_EQ(optimized_mul_name, new_id->input(0));
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors.size());
@@ -633,24 +661,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
ASSERT_NE(add_4_node, nullptr);
EXPECT_EQ("Add", add_4_node->op());
EXPECT_EQ(2, add_4_node->input_size());
- EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
+ EXPECT_EQ(AggregationConstName("Add"), add_4_node->input(0));
+ EXPECT_EQ(AggregationConstName("Add_1"), add_4_node->input(1));
const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
ASSERT_NE(add_5_node, nullptr);
EXPECT_EQ("Add", add_5_node->op());
EXPECT_EQ(2, add_5_node->input_size());
- EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
+ EXPECT_EQ(AggregationConstName("Add"), add_5_node->input(0));
+ EXPECT_EQ(AggregationConstName("Add_1"), add_5_node->input(1));
- const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
+ const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
ASSERT_NE(add_const_node, nullptr);
EXPECT_EQ("Const", add_const_node->op());
EXPECT_EQ(1, add_const_node->input_size());
EXPECT_EQ("^Placeholder", add_const_node->input(0));
const NodeDef* add_1_const_node =
- node_map.GetNode(OptimizedName("Add_1_const"));
+ node_map.GetNode(AggregationConstName("Add_1"));
ASSERT_NE(add_1_const_node, nullptr);
EXPECT_EQ("Const", add_1_const_node->op());
EXPECT_EQ(1, add_1_const_node->input_size());