aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-21 10:11:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-21 10:16:07 -0700
commitaa5d3126ced57f8117678bb1cb5cc41e2a72eb9a (patch)
tree2d5cde5e4a17d6f39bfcaa83cb84c9734f0305f8
parentdae7a75734f2137aae7130e064fab9dfcb799c45 (diff)
Support while ops to be CSEd in HLO. Do so if they have same bodies, conditions,
and init conditions. PiperOrigin-RevId: 205524367
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc179
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc11
2 files changed, 187 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 76b9c66651..90fbaa37c5 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -239,7 +239,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_EQ(5, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
- HloCSE cse(/*is_layout_sensitive=*/false);
+ HloCSE cse(/*is_layout_sensitive=*/true);
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
@@ -248,6 +248,183 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand));
}
+// Test two identical while loops with same inputs
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[])
+{ %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 =
+(f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[],
+f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
+%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition.1, body=%body
+ }
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(4, computation->instruction_count());
+}
+
+// Test two while loops with same conditions, same inputs, but different
+// bodies
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) {
+ %param.1 = (f32[], f32[]) parameter(0)
+ %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1),
+index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1),
+index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
+%get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[]
+%get-tuple-element.2, f32[] %sub)
+ }
+
+ %condition (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.3: (f32[], f32[])) -> pred[] {
+ %param.3 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () ->
+(f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
+f32[]) %tuple.1), condition=%condition.1, body=%body2
+ }
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(5, computation->instruction_count());
+}
+
+// Test two identical while loops with different inputs
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[],
+f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 =
+f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
+%constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2),
+condition=%condition.1, body=%body
+ }
+
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(8, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(8, computation->instruction_count());
+}
+
+// Test two while loops with identical bodies and same inputs, but different
+// conditions
+TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(true)
+ }
+
+ ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[],
+f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
+f32[]) %tuple.1), condition=%condition.1, body=%body
+ })")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(5, computation->instruction_count());
+}
+
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
// Test that two identical instructions with different layouts are *not*
// commoned if the pass is layout sensitive.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 7685c822f4..8b9bdd2f46 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1522,8 +1522,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kTupleSelect:
return true;
- // These opcodes have complex or special behavior so just return false.
- case HloOpcode::kWhile:
+ // This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
return false;
@@ -1539,6 +1538,14 @@ bool HloInstruction::IdenticalSlowPath(
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
+ case HloOpcode::kWhile: {
+ if (eq_computations(while_body(), other.while_body()) &&
+ eq_computations(while_condition(), other.while_condition())) {
+ return true;
+ }
+ return false;
+ }
+
case HloOpcode::kDomain:
return operand_side_metadata().Matches(other.operand_side_metadata()) &&
user_side_metadata().Matches(other.user_side_metadata());