diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-10 15:43:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 15:46:47 -0700 |
commit | f7e24ab1113ae7094e4831a606a29e0d5b956bfe (patch) | |
tree | e6e382c9fb8b747d844d145b15822464ffe853eb /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | ff7f7a566b356a7e2de2b8f174d0f09e673179f4 (diff) |
Remove cancelling pairs of transposes that are separated by a non-branching chain of ops that preserve value, order, and shape. Off by default.
PiperOrigin-RevId: 196183111
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d60c3124ed..d648fa0787 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -1122,7 +1122,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) { ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT); Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4}); Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4}); - Output perm3 = ops::Const(s.WithOpName("perm2"), {0, 1, 2, 3}, {4}); + Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4}); Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1); Output transpose2 = ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2); @@ -1248,6 +1248,47 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { EXPECT_EQ(6, output.node_size()); } +TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs_shape = + ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4}); + Output inputs = + ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT); + Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4}); + Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4}); + Output transpose1 = ops::Transpose( + s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1); + Output identity = ops::Identity(s.WithOpName("id"), transpose1); + Output transpose2 = + ops::Transpose(s.WithOpName("transpose2"), identity, perm2); + Output id1 = ops::Identity(s.WithOpName("id1"), transpose2); + + GrapplerItem item; + item.fetch = {"id1"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveIdentityTranspose(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + std::set<string> nodes_after_optimization; + for (const NodeDef& node : output.node()) { + nodes_after_optimization.insert(node.name()); + if (node.name() == "id") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("inputs", node.input(0)); + EXPECT_EQ("^perm2", node.input(1)); + } + if (node.name() == "id1") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("id", node.input(0)); + } + } + EXPECT_EQ(nodes_after_optimization, + std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"})); +} + TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT, |