aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-09-18 21:52:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 21:56:09 -0700
commitd7cc73c300b12e7c02507bcfaff146d6c4955f19 (patch)
tree4a2cfc40b6b158eb55f9f7f663cb1ce63130fdea
parent50e7f03591a5d2b6b2abc29e5549ea0077259706 (diff)
[TF:XLA] Change HloPtrComparator to work across HLO modules. Declaring the method out of line does not increase compile time.
PiperOrigin-RevId: 213571783
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h14
2 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e905f2983a..ad58833e4d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2910,6 +2910,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
+ auto lhs_module = lhs->GetModule();
+ auto rhs_module = rhs->GetModule();
+ CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+ (lhs_module != nullptr && rhs_module != nullptr));
+ if (lhs_module != nullptr &&
+ lhs_module->unique_id() != rhs_module->unique_id()) {
+ return lhs_module->unique_id() < rhs_module->unique_id();
+ }
+ return lhs->unique_id() < rhs->unique_id();
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1ef8cd5036..d615df0831 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1693,21 +1693,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
- const HloInstruction* const& rhs) const {
- if (rhs == nullptr) {
- // Nothing compares less than nullptr.
- return false;
- }
- if (lhs == nullptr) {
- return true;
- }
- return lhs->unique_id() < rhs->unique_id();
- }
+ const HloInstruction* const& rhs) const;
};
template <typename ValueT>