aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-03 10:54:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 10:58:14 -0700
commit46ffa99df62b3ecdab65f9bbf202921205d59e68 (patch)
treef770adaf057fc9197d03dd5a6cf69b3466eda3f9
parent96c415ad77c20e1cf2da5e61f85e24fd6c36eb28 (diff)
Add optimization to fuse Conj with Transpose or ConjugateTranspose.
PiperOrigin-RevId: 174483320
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc25
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc73
2 files changed, 98 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 445e5cf972..f1c31ebb25 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -864,6 +864,31 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return new_mul_node->name();
}
}
+
+ // Fuse ops by absorbing Conj into Transpose or ConjugateTranspose.
+ if (node->op() == "Conj" || node->op() == "Transpose" ||
+ node->op() == "ConjugateTranspose") {
+ const NodeDef* input = node_map->GetNode(node->input(0));
+ const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
+ const NodeDef* conj_op = node->op() == "Conj" ? node : input;
+ if ((transpose_op->op() == "Transpose" ||
+ transpose_op->op() == "ConjugateTranspose") &&
+ conj_op->op() == "Conj") {
+ NodeDef* new_op = graph_def->add_node();
+ *new_op = *transpose_op;
+ new_op->set_name(node->name() + "_fused");
+ // Flip the type of transpose op to absorb the conjugation.
+ new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
+ : "Transpose");
+ new_op->set_input(0, input->input(0));
+ node_map->AddNode(new_op->name(), new_op);
+ node_map->UpdateInput(new_op->name(), node->name(), input->input(0));
+ AddFrameControlDeps(node, {new_op}, "", {}, graph_def, node_map,
+ frame_map);
+ return new_op->name();
+ }
+ }
+
return "";
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 5c3fdd2553..c1535886d1 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -191,6 +191,79 @@ TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
EXPECT_EQ("mul1_hoist", new_id.input(0));
}
+TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+ Output z = ops::Complex(s.WithOpName("z"), re, im);
+ Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
+ Output conj = ops::Conj(s.WithOpName("conj"), z);
+ Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ("trans_fused", output.node(6).name());
+ EXPECT_EQ("ConjugateTranspose", output.node(6).op());
+ EXPECT_EQ("z", output.node(6).input(0));
+ EXPECT_EQ("perm", output.node(6).input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+ Output z = ops::Complex(s.WithOpName("z"), re, im);
+ Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
+ Output conj = ops::Conj(s.WithOpName("conj"), z);
+ Output transp =
+ ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ("conjugate_trans_fused", output.node(6).name());
+ EXPECT_EQ("Transpose", output.node(6).op());
+ EXPECT_EQ("z", output.node(6).input(0));
+ EXPECT_EQ("perm", output.node(6).input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+ Output z = ops::Complex(s.WithOpName("z"), re, im);
+ Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
+ Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
+ Output conj = ops::Conj(s.WithOpName("conj"), trans);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ("conj_fused", output.node(6).name());
+ EXPECT_EQ("ConjugateTranspose", output.node(6).op());
+ EXPECT_EQ("z", output.node(6).input(0));
+ EXPECT_EQ("perm", output.node(6).input(1));
+}
+
TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =