diff options
author | 2018-07-10 06:13:26 -0700 | |
---|---|---|
committer | 2018-07-10 06:16:52 -0700 | |
commit | 13da30d465a017207d9ca2116f978f0d2a9d15b5 (patch) | |
tree | c1159fb3bc3cf42690056ad9532224be6e820e4e | |
parent | 4657eb350f3a50b0c8cff6e62a1927f25b1f38bf (diff) |
Teach HLO CSE about domain instructions
Previously we didn't applied CSE for domain instructions at all but if
two domain instruction have the same entry and exit metadata then it is
safe to CSE them.
PiperOrigin-RevId: 203934674
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse_test.cc | 35 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 5 |
3 files changed, 41 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 3f1deec2df..06484f4012 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -143,10 +143,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { if (instruction->operand_count() == 0) { continue; } - // Skip instructions which have side effects or are a domain (which must - // not be CSE-ed). - if (instruction->HasSideEffect() || - instruction->opcode() == HloOpcode::kDomain) { + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index c98a79fc71..76b9c66651 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -536,5 +536,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); } +TEST_F(HloCseTest, Domain) { + auto module = ParseHloString(R"( +HloModule module +ENTRY %entry { + %param = f32[] parameter(0), sharding={maximal device=0} + %domain.0 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.1 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.2 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}} + %negate.0 = f32[] negate(%domain.0) + %negate.1 = f32[] negate(%domain.1) + %negate.2 = f32[] negate(%domain.2) + %domain.3 = f32[] domain(%negate.0), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.4 = f32[] domain(%negate.1), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.5 = f32[] domain(%negate.2), + domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} + %add = f32[] add(%domain.3, %domain.4) + ROOT %sub = f32[] subtract(%add, %domain.5) +})") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + LOG(INFO) << "AAAAA " << module->ToString(); + const HloInstruction* sub = module->entry_computation()->root_instruction(); + const HloInstruction* add = sub->operand(0); + EXPECT_EQ(add->operand(0), add->operand(1)); + EXPECT_NE(add->operand(0), sub->operand(1)); + EXPECT_NE(add->operand(1), sub->operand(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f5fa255693..ef7b5e3924 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1528,7 +1528,6 @@ bool HloInstruction::IdenticalSlowPath( return true; // These opcodes have complex or special behavior so just return false. - case HloOpcode::kDomain: case HloOpcode::kWhile: case HloOpcode::kAfterAll: return false; @@ -1550,6 +1549,10 @@ bool HloInstruction::IdenticalSlowPath( return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); + case HloOpcode::kDomain: + return operand_side_metadata().Matches(other.operand_side_metadata()) && + user_side_metadata().Matches(other.user_side_metadata()); + // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: |