aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-10 15:43:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 15:46:47 -0700
commitf7e24ab1113ae7094e4831a606a29e0d5b956bfe (patch)
treee6e382c9fb8b747d844d145b15822464ffe853eb /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentff7f7a566b356a7e2de2b8f174d0f09e673179f4 (diff)
Remove cancelling pairs of transposes that are separated by a non-branching chain of ops that preserve value, order, and shape. Off by default.
PiperOrigin-RevId: 196183111
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc43
1 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d60c3124ed..d648fa0787 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -1122,7 +1122,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
- Output perm3 = ops::Const(s.WithOpName("perm2"), {0, 1, 2, 3}, {4});
+ Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
Output transpose2 =
ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
@@ -1248,6 +1248,47 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
EXPECT_EQ(6, output.node_size());
}
+TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs_shape =
+ ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
+ Output inputs =
+ ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
+ Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
+ Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
+ Output transpose1 = ops::Transpose(
+ s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
+ Output identity = ops::Identity(s.WithOpName("id"), transpose1);
+ Output transpose2 =
+ ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
+ Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);
+
+ GrapplerItem item;
+ item.fetch = {"id1"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveIdentityTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ std::set<string> nodes_after_optimization;
+ for (const NodeDef& node : output.node()) {
+ nodes_after_optimization.insert(node.name());
+ if (node.name() == "id") {
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("inputs", node.input(0));
+ EXPECT_EQ("^perm2", node.input(1));
+ }
+ if (node.name() == "id1") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("id", node.input(0));
+ }
+ }
+ EXPECT_EQ(nodes_after_optimization,
+ std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"}));
+}
+
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,