aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cse_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cse_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc179
1 files changed, 178 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 76b9c66651..90fbaa37c5 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -239,7 +239,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_EQ(5, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
- HloCSE cse(/*is_layout_sensitive=*/false);
+ HloCSE cse(/*is_layout_sensitive=*/true);
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
@@ -248,6 +248,183 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand));
}
+// Test two identical while loops with same inputs
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[])
+{ %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 =
+(f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[],
+f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
+%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition.1, body=%body
+ }
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(4, computation->instruction_count());
+}
+
+// Test two while loops with same conditions, same inputs, but different
+// bodies
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) {
+ %param.1 = (f32[], f32[]) parameter(0)
+ %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1),
+index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1),
+index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
+%get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[]
+%get-tuple-element.2, f32[] %sub)
+ }
+
+ %condition (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.3: (f32[], f32[])) -> pred[] {
+ %param.3 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () ->
+(f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
+f32[]) %tuple.1), condition=%condition.1, body=%body2
+ }
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(5, computation->instruction_count());
+}
+
+// Test two identical while loops with different inputs
+TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(false)
+ }
+
+ ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[],
+f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 =
+f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
+%constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2),
+condition=%condition.1, body=%body
+ }
+
+ )")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(8, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(8, computation->instruction_count());
+}
+
+// Test two while loops with identical bodies and same inputs, but different
+// conditions
+TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
+ auto module = ParseHloString(R"(
+ HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
+
+ %body (param: (f32[], f32[])) -> (f32[], f32[]) {
+ %param = (f32[], f32[]) parameter(0)
+ %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
+index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
+index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
+ ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
+ }
+
+ %condition (param.1: (f32[], f32[])) -> pred[] {
+ %param.1 = (f32[], f32[]) parameter(0)
+ ROOT %constant = pred[] constant(false)
+ }
+
+ %condition.1 (param.2: (f32[], f32[])) -> pred[] {
+ %param.2 = (f32[], f32[]) parameter(0)
+ ROOT %constant.1 = pred[] constant(true)
+ }
+
+ ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[],
+f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
+ %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
+ %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
+condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
+f32[]) %tuple.1), condition=%condition.1, body=%body
+ })")
+ .ValueOrDie();
+
+ auto computation = module->entry_computation();
+
+ EXPECT_EQ(5, computation->instruction_count());
+ HloCSE cse(true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(5, computation->instruction_count());
+}
+
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
// Test that two identical instructions with different layouts are *not*
// commoned if the pass is layout sensitive.