diff options
author | Yuanzhong Xu <yuanzx@google.com> | 2018-02-16 14:17:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-16 14:20:59 -0800 |
commit | ea70fb58f923a2c86ccc14cd38618afdc0dfa1bc (patch) | |
tree | 5a0e734d7f27e4a03ddea33f2b830f39f1ecba9b | |
parent | 785ee91c0d4f9a0e8eafa082f725c25ae134c9b3 (diff) | |
[XLA] HLO scheduling: update entries in ready queue when priority changes.
PiperOrigin-RevId: 186045619
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_scheduling.cc | 77 |
1 files changed, 41 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8dc4d4f7ba..f6e33403f5 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include <queue> +#include <map> #include <utility> #include <vector> @@ -151,8 +151,10 @@ class ListScheduler { int64 bytes_defined; // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. - std::vector<std::pair<const LogicalBuffer* const, int64>> + // U is the number of uses of B that have not yet been scheduled. This pair + // is a pointer into the unscheduled_use_count_ map, so it gets updated for + // free when we update counts in the map. + std::vector<const std::pair<const LogicalBuffer* const, int64>*> used_buffer_unscheduled_use_counts; }; @@ -175,8 +177,8 @@ class ListScheduler { } auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); - entry.used_buffer_unscheduled_use_counts.emplace_back( - unscheduled_use_count_it->first, unscheduled_use_count_it->second); + entry.used_buffer_unscheduled_use_counts.push_back( + &*unscheduled_use_count_it); } return entry; } @@ -185,8 +187,8 @@ class ListScheduler { int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { - auto buffer = kv.first; - auto use_count = kv.second; + auto buffer = kv->first; + auto use_count = kv->second; if (use_count == 1) { freed_bytes += size_function_(*buffer); } @@ -217,23 +219,18 @@ class ListScheduler { } } - auto priority_comparator = - [this](const std::pair<Priority, ReadyListEntry>& lhs, - const std::pair<Priority, ReadyListEntry>& rhs) { - return lhs.first < 
rhs.first; - }; - std::priority_queue<std::pair<Priority, ReadyListEntry>, - std::vector<std::pair<Priority, ReadyListEntry>>, - decltype(priority_comparator)> - ready_queue(priority_comparator); + // Use a multimap to sort ReadyListEntry according to their priority. + std::multimap<Priority, ReadyListEntry> ready_queue; - // Set of instructions in the ready list. - tensorflow::gtl::FlatSet<const HloInstruction*> ready_instructions; + // Map of ready instructions to their iterators in ready_queue. + tensorflow::gtl::FlatMap<const HloInstruction*, + std::multimap<Priority, ReadyListEntry>::iterator> + ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { auto entry = MakeReadyListEntry(inst); - ready_queue.emplace(GetPriority(entry), std::move(entry)); - ready_instructions.insert(inst); + auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions[inst] = it; }; for (auto* instruction : computation_.instructions()) { @@ -247,14 +244,10 @@ class ListScheduler { while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. - const HloInstruction* best = ready_queue.top().second.instruction; - ready_queue.pop(); - // We may have duplicates in the priority queue, because when a ready - // instruction's priority goes up, we reinsert it to the priority queue. - // Skip the duplicate. - if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) { - continue; - } + auto best_it = ready_queue.end(); + --best_it; + const HloInstruction* best = best_it->second.instruction; + ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); scheduled_instructions_.insert(best); @@ -287,16 +280,27 @@ class ListScheduler { update_pred_count(succ); } // The unscheduled use count for a buffer has changed to 1, so the - // priorities of some ready instructions may go up. 
We reinsert them to - // the priority queue, so that they can appear earlier. The old entries - // will become duplicates and will be skipped. + // priorities of some ready instructions may go up. We update them in the + // ready queue, so that they can appear earlier. if (adjust_ready_queue) { for (HloInstruction* operand : best->operands()) { for (HloInstruction* operand_user : operand->users()) { - if (ready_instructions.find(operand_user) != - ready_instructions.end()) { - add_to_ready_queue(operand_user); + auto ready_instructions_it = ready_instructions.find(operand_user); + if (ready_instructions_it == ready_instructions.end()) { + continue; + } + auto ready_queue_it = ready_instructions_it->second; + auto& entry = ready_queue_it->second; + Priority new_priority = GetPriority(entry); + if (new_priority == ready_queue_it->first) { + continue; } + // Create a new entry in ready_queue, then update + // ready_instructions[operand_user] to refer to the new entry. + ready_instructions_it->second = + ready_queue.emplace(new_priority, std::move(entry)); + // Remove the old entry in ready_queue. + ready_queue.erase(ready_queue_it); } } } @@ -317,8 +321,9 @@ class ListScheduler { buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map. - tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> unscheduled_use_count_; + // LogicalBuffer. We rely on iterator stability in this map, and that the map + // entries are std::pair's. + std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; // Set of instructions which have been scheduled. tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_; |