diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_map.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_map.cc | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index ebd5adb5d5..9e096320db 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -41,11 +41,15 @@ namespace xla { bool HloDomainMap::InSameDomain(HloInstruction* instruction1, HloInstruction* instruction2) const { - int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); - int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + int64 domain_id1 = GetDomainId(instruction1); + int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } +int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { + return FindOrDefault(instruction_to_domain_, instruction, -1); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } + if (instruction == instruction->parent()->root_instruction()) { + auto domain = MakeUnique<DomainMetadata::Domain>(); + domain->enter_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } return Status::OK(); } |