diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_remover.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_remover.cc | 48 |
1 files changed, 5 insertions, 43 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc index 1d06040b0e..67fad0769f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,54 +43,16 @@ class HloDomainRemover::RunContext { Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( const DomainMetadata::Domain& domain) { - // Verify that the whole kDomain frontier bounding the instruction reach set, - // has matching metadata. - // A kDomain instruction has two sides of metadata, a user facing and an - // operand facing. - // A reachable instruction set can make contact with a kDomain instruction on - // a user facing side (the kDomain is operand of the instruction), or on a - // operand facing side (the kDomain is user of the instruction). - // And depending on the contact side, the proper metadata object - // (user_side_metadata() vs. operand_side_metadata()) needs to be used for - // consistency checks. - const DomainMetadata* ref_metadata = nullptr; - VLOG(4) << "Reach set:"; - for (HloInstruction* instruction : domain.instructions) { - VLOG(4) << " " << instruction->name(); - } - VLOG(4) << " Domains:"; - for (HloInstruction* instruction : domain.enter_domains) { - const DomainMetadata& meta = instruction->user_side_metadata(); - VLOG(4) << " User side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } - for (HloInstruction* instruction : domain.exit_domains) { - const DomainMetadata& meta = instruction->operand_side_metadata(); - VLOG(4) << " Operand side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } + TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata, + HloDomainVerifier::VerifyDomain(domain)); if (ref_metadata != nullptr) { VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); - TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, ref_metadata)); } else { // No kDomain instruction was present within this domain, so call the // generic normalization functions and have them apply their heuristic. VLOG(2) << "Applying domain-less normalization"; - TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, nullptr)); } return Status::OK(); } |