diff options
author | Dimitris Vardoulakis <dimvar@google.com> | 2018-09-18 21:52:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 21:56:09 -0700 |
commit | d7cc73c300b12e7c02507bcfaff146d6c4955f19 (patch) | |
tree | 4a2cfc40b6b158eb55f9f7f663cb1ce63130fdea | |
parent | 50e7f03591a5d2b6b2abc29e5549ea0077259706 (diff) |
[TF:XLA] Change HloPtrComparator to work across HLO modules. Declaring the method out of line does not increase compile time.
PiperOrigin-RevId: 213571783
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 14 |
2 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e905f2983a..ad58833e4d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2910,6 +2910,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } +bool HloPtrComparator::operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } + auto lhs_module = lhs->GetModule(); + auto rhs_module = rhs->GetModule(); + CHECK((lhs_module == nullptr && rhs_module == nullptr) || + (lhs_module != nullptr && rhs_module != nullptr)); + if (lhs_module != nullptr && + lhs_module->unique_id() != rhs_module->unique_id()) { + return lhs_module->unique_id() < rhs_module->unique_id(); + } + return lhs->unique_id() < rhs->unique_id(); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1ef8cd5036..d615df0831 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1693,21 +1693,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. Exception: null pointer values compare less than non-null. -// -// Note that this cannot be used for HLO instructions across multiple modules -// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, - const HloInstruction* const& rhs) const { - if (rhs == nullptr) { - // Nothing compares less than nullptr. - return false; - } - if (lhs == nullptr) { - return true; - } - return lhs->unique_id() < rhs->unique_id(); - } + const HloInstruction* const& rhs) const; }; template <typename ValueT> |