aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-24 13:21:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 13:24:49 -0700
commit4355b923c273a4e07655f860a95428b2db977741 (patch)
treec714250d587cb5672f59e2ecda107c416254a1d8 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent893aa776009418c841d49c924207f3cdaf1d5174 (diff)
Implement hoisting of common prefix of unary ops to concat.
PiperOrigin-RevId: 194135148
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc102
1 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index cb1f2ea732..df10dbdf48 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -98,6 +98,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+ options.dedup_computations = false;
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
@@ -147,6 +148,10 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
}
+ void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.hoist_unary_out_of_concat = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2086,5 +2091,102 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
EXPECT_EQ("mul1", mul3_node->input(1));
}
+TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ Output b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
+ Output c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+ Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
+ Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
+ Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
+ // Test case with chains of length 1.
+ Output sin_a =
+ ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
+ Output exp_a =
+ ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
+ Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
+ Output exp_c =
+ ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
+ Output concat =
+ ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
+ Output id = ops::Identity(s.WithOpName("id"), concat);
+
+ // Test case with chains of length 2.
+ Output exp_a2 =
+ ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
+ Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
+ Output exp_c2 =
+ ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
+ Output cos_exp_a2 = ops::Cos(
+ s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
+ Output cos_exp_b2 = ops::Cos(
+ s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
+ Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
+ Output concat2 = ops::Concat(s.WithOpName("concat2"),
+ {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
+ Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
+ GrapplerItem item;
+ item.fetch = {"id", "id2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCWiseUnaryFromConcat(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "concat") {
+ EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ("sin_a", node.input(0));
+ EXPECT_EQ("b", node.input(1));
+ EXPECT_EQ("c", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ EXPECT_EQ("^ctrl1", node.input(4));
+ EXPECT_EQ("^ctrl2", node.input(5));
+ found++;
+ }
+ if (node.name() == "exp_a") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("concat", node.input(0));
+ found++;
+ }
+ if (node.name() == "id") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("exp_a", node.input(0));
+ found++;
+ }
+
+ if (node.name() == "concat2") {
+ EXPECT_EQ(7, node.input_size());
+ EXPECT_EQ("sin_a", node.input(0));
+ EXPECT_EQ("b", node.input(1));
+ EXPECT_EQ("c", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ EXPECT_EQ("^ctrl1", node.input(4));
+ EXPECT_EQ("^ctrl2", node.input(5));
+ EXPECT_EQ("^ctrl3", node.input(6));
+ found++;
+ }
+ if (node.name() == "exp_a2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("concat2", node.input(0));
+ found++;
+ }
+ if (node.name() == "cos_exp_a2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("exp_a2", node.input(0));
+ found++;
+ }
+ if (node.name() == "id2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("cos_exp_a2", node.input(0));
+ found++;
+ }
+ }
+ EXPECT_EQ(7, found);
+}
+
} // namespace grappler
} // namespace tensorflow