From 46ffa99df62b3ecdab65f9bbf202921205d59e68 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 3 Nov 2017 10:54:10 -0700 Subject: Add optimization to fuse Conj with Transpose or ConjugateTranspose. PiperOrigin-RevId: 174483320 --- .../grappler/optimizers/arithmetic_optimizer.cc | 25 ++++++++ .../optimizers/arithmetic_optimizer_test.cc | 73 ++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 445e5cf972..f1c31ebb25 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -864,6 +864,31 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( return new_mul_node->name(); } } + + // Fuse ops by absorbing Conj into Transpose or ConjugateTranspose. + if (node->op() == "Conj" || node->op() == "Transpose" || + node->op() == "ConjugateTranspose") { + const NodeDef* input = node_map->GetNode(node->input(0)); + const NodeDef* transpose_op = node->op() == "Conj" ? input : node; + const NodeDef* conj_op = node->op() == "Conj" ? node : input; + if ((transpose_op->op() == "Transpose" || + transpose_op->op() == "ConjugateTranspose") && + conj_op->op() == "Conj") { + NodeDef* new_op = graph_def->add_node(); + *new_op = *transpose_op; + new_op->set_name(node->name() + "_fused"); + // Flip the type of transpose op to absorb the conjugation. + new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose" + : "Transpose"); + new_op->set_input(0, input->input(0)); + node_map->AddNode(new_op->name(), new_op); + node_map->UpdateInput(new_op->name(), node->name(), input->input(0)); + AddFrameControlDeps(node, {new_op}, "", {}, graph_def, node_map, + frame_map); + return new_op->name(); + } + } + return ""; } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 5c3fdd2553..c1535886d1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -191,6 +191,79 @@ TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) { EXPECT_EQ("mul1_hoist", new_id.input(0)); } +TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output z = ops::Complex(s.WithOpName("z"), re, im); + Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); + Output conj = ops::Conj(s.WithOpName("conj"), z); + Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + LOG(INFO) << output.DebugString(); + + EXPECT_EQ(7, output.node_size()); + EXPECT_EQ("trans_fused", output.node(6).name()); + EXPECT_EQ("ConjugateTranspose", output.node(6).op()); + EXPECT_EQ("z", output.node(6).input(0)); + EXPECT_EQ("perm", output.node(6).input(1)); +} + +TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output z = ops::Complex(s.WithOpName("z"), re, im); + Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); + Output conj = ops::Conj(s.WithOpName("conj"), z); + Output transp = + ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + LOG(INFO) << output.DebugString(); + + EXPECT_EQ(7, output.node_size()); + EXPECT_EQ("conjugate_trans_fused", output.node(6).name()); + EXPECT_EQ("Transpose", output.node(6).op()); + EXPECT_EQ("z", output.node(6).input(0)); + EXPECT_EQ("perm", output.node(6).input(1)); +} + +TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); + Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); + Output z = ops::Complex(s.WithOpName("z"), re, im); + Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); + Output trans = ops::Transpose(s.WithOpName("trans"), z, perm); + Output conj = ops::Conj(s.WithOpName("conj"), trans); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + LOG(INFO) << output.DebugString(); + + EXPECT_EQ(7, output.node_size()); + EXPECT_EQ("conj_fused", output.node(6).name()); + EXPECT_EQ("ConjugateTranspose", output.node(6).op()); + EXPECT_EQ("z", output.node(6).input(0)); + EXPECT_EQ("perm", output.node(6).input(1)); +} + TEST_F(ArithmeticOptimizerTest, IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = -- cgit v1.2.3