diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-24 13:21:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 13:24:49 -0700 |
commit | 4355b923c273a4e07655f860a95428b2db977741 (patch) | |
tree | c714250d587cb5672f59e2ecda107c416254a1d8 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 893aa776009418c841d49c924207f3cdaf1d5174 (diff) |
Implement hoisting of common prefix of unary ops to concat.
PiperOrigin-RevId: 194135148
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index cb1f2ea732..df10dbdf48 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -98,6 +98,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { // should explicitly enable required optimization for tests isolation void DisableAllStages(ArithmeticOptimizer* optimizer) { ArithmeticOptimizer::ArithmeticOptimizerOptions options; + options.dedup_computations = false; options.enable_try_simplify_and_replace = false; options.combine_add_to_addn = false; options.hoist_common_factor_out_of_aggregation = false; @@ -147,6 +148,10 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_negation = true; } + void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_unary_out_of_concat = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2086,5 +2091,102 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) { EXPECT_EQ("mul1", mul3_node->input(1)); } +TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT); + Output b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT); + Output c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT); + Output axis = ops::Const(s.WithOpName("axis"), 0, {}); + Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {}); + Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {}); + Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {}); + // Test case with chains of length 1. + Output sin_a = + ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a); + Output exp_a = + ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a); + Output exp_b = ops::Exp(s.WithOpName("exp_b"), b); + Output exp_c = + ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c); + Output concat = + ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis); + Output id = ops::Identity(s.WithOpName("id"), concat); + + // Test case with chains of length 2. + Output exp_a2 = + ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a); + Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b); + Output exp_c2 = + ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c); + Output cos_exp_a2 = ops::Cos( + s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2); + Output cos_exp_b2 = ops::Cos( + s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2); + Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2); + Output concat2 = ops::Concat(s.WithOpName("concat2"), + {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis); + Output id2 = ops::Identity(s.WithOpName("id2"), concat2); + GrapplerItem item; + item.fetch = {"id", "id2"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyHoistCWiseUnaryFromConcat(&optimizer); + + OptimizeAndPrune(&optimizer, &item, &output); + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "concat") { + EXPECT_EQ(6, node.input_size()); + EXPECT_EQ("sin_a", node.input(0)); + EXPECT_EQ("b", node.input(1)); + EXPECT_EQ("c", node.input(2)); + EXPECT_EQ("axis", node.input(3)); + EXPECT_EQ("^ctrl1", node.input(4)); + EXPECT_EQ("^ctrl2", node.input(5)); + found++; + } + if (node.name() == "exp_a") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("concat", node.input(0)); + found++; + } + if (node.name() == "id") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("exp_a", node.input(0)); + found++; + } + + if (node.name() == "concat2") { + EXPECT_EQ(7, node.input_size()); + EXPECT_EQ("sin_a", node.input(0)); + EXPECT_EQ("b", node.input(1)); + EXPECT_EQ("c", node.input(2)); + EXPECT_EQ("axis", node.input(3)); + EXPECT_EQ("^ctrl1", node.input(4)); + EXPECT_EQ("^ctrl2", node.input(5)); + EXPECT_EQ("^ctrl3", node.input(6)); + found++; + } + if (node.name() == "exp_a2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("concat2", node.input(0)); + found++; + } + if (node.name() == "cos_exp_a2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("exp_a2", node.input(0)); + found++; + } + if (node.name() == "id2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("cos_exp_a2", node.input(0)); + found++; + } + } + EXPECT_EQ(7, found); +} + } // namespace grappler } // namespace tensorflow |