aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_domain_map.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_map.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index ebd5adb5d5..9e096320db 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -41,11 +41,15 @@ namespace xla {
bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
HloInstruction* instruction2) const {
- int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1);
- int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1);
+ int64 domain_id1 = GetDomainId(instruction1);
+ int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
+int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+ return FindOrDefault(instruction_to_domain_, instruction, -1);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ domain->enter_domains.insert(instruction);
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
return Status::OK();
}