aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2018-07-10 06:13:26 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-07-10 06:16:52 -0700
commit: 13da30d465a017207d9ca2116f978f0d2a9d15b5 (patch)
tree: c1159fb3bc3cf42690056ad9532224be6e820e4e
parent: 4657eb350f3a50b0c8cff6e62a1927f25b1f38bf (diff)
Teach HLO CSE about domain instructions
Previously we did not apply CSE to domain instructions at all, but if two domain instructions have the same entry and exit metadata then it is safe to CSE them. PiperOrigin-RevId: 203934674
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse.cc 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse_test.cc 35
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 5
3 files changed, 41 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 3f1deec2df..06484f4012 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -143,10 +143,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
if (instruction->operand_count() == 0) {
continue;
}
- // Skip instructions which have side effects or are a domain (which must
- // not be CSE-ed).
- if (instruction->HasSideEffect() ||
- instruction->opcode() == HloOpcode::kDomain) {
+ // Skip instructions which have side effects.
+ if (instruction->HasSideEffect()) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index c98a79fc71..76b9c66651 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -536,5 +536,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
}
+TEST_F(HloCseTest, Domain) {
+ auto module = ParseHloString(R"(
+HloModule module
+ENTRY %entry {
+ %param = f32[] parameter(0), sharding={maximal device=0}
+ %domain.0 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.1 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.2 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
+ %negate.0 = f32[] negate(%domain.0)
+ %negate.1 = f32[] negate(%domain.1)
+ %negate.2 = f32[] negate(%domain.2)
+ %domain.3 = f32[] domain(%negate.0),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.4 = f32[] domain(%negate.1),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.5 = f32[] domain(%negate.2),
+ domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
+ %add = f32[] add(%domain.3, %domain.4)
+ ROOT %sub = f32[] subtract(%add, %domain.5)
+})")
+ .ValueOrDie();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ LOG(INFO) << "AAAAA " << module->ToString();
+ const HloInstruction* sub = module->entry_computation()->root_instruction();
+ const HloInstruction* add = sub->operand(0);
+ EXPECT_EQ(add->operand(0), add->operand(1));
+ EXPECT_NE(add->operand(0), sub->operand(1));
+ EXPECT_NE(add->operand(1), sub->operand(1));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f5fa255693..ef7b5e3924 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1528,7 +1528,6 @@ bool HloInstruction::IdenticalSlowPath(
return true;
// These opcodes have complex or special behavior so just return false.
- case HloOpcode::kDomain:
case HloOpcode::kWhile:
case HloOpcode::kAfterAll:
return false;
@@ -1550,6 +1549,10 @@ bool HloInstruction::IdenticalSlowPath(
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
+ case HloOpcode::kDomain:
+ return operand_side_metadata().Matches(other.operand_side_metadata()) &&
+ user_side_metadata().Matches(other.user_side_metadata());
+
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining: