about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Yuanzhong Xu <yuanzx@google.com> 2018-02-16 14:17:13 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-02-16 14:20:59 -0800
commit: ea70fb58f923a2c86ccc14cd38618afdc0dfa1bc (patch)
tree: 5a0e734d7f27e4a03ddea33f2b830f39f1ecba9b
parent: 785ee91c0d4f9a0e8eafa082f725c25ae134c9b3 (diff)
[XLA] HLO scheduling: update entries in ready queue when priority changes.
PiperOrigin-RevId: 186045619
-rw-r--r-- tensorflow/compiler/xla/service/hlo_scheduling.cc 77
1 files changed, 41 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 8dc4d4f7ba..f6e33403f5 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
-#include <queue>
+#include <map>
#include <utility>
#include <vector>
@@ -151,8 +151,10 @@ class ListScheduler {
int64 bytes_defined;
// For each buffer B used by this instruction, we keep a pair (B, U), where
- // U is the number of uses of B that have not yet been scheduled.
- std::vector<std::pair<const LogicalBuffer* const, int64>>
+ // U is the number of uses of B that have not yet been scheduled. This pair
+ // is a pointer into the unscheduled_use_count_ map, so it gets updated for
+ // free when we update counts in the map.
+ std::vector<const std::pair<const LogicalBuffer* const, int64>*>
used_buffer_unscheduled_use_counts;
};
@@ -175,8 +177,8 @@ class ListScheduler {
}
auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
- entry.used_buffer_unscheduled_use_counts.emplace_back(
- unscheduled_use_count_it->first, unscheduled_use_count_it->second);
+ entry.used_buffer_unscheduled_use_counts.push_back(
+ &*unscheduled_use_count_it);
}
return entry;
}
@@ -185,8 +187,8 @@ class ListScheduler {
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
- auto buffer = kv.first;
- auto use_count = kv.second;
+ auto buffer = kv->first;
+ auto use_count = kv->second;
if (use_count == 1) {
freed_bytes += size_function_(*buffer);
}
@@ -217,23 +219,18 @@ class ListScheduler {
}
}
- auto priority_comparator =
- [this](const std::pair<Priority, ReadyListEntry>& lhs,
- const std::pair<Priority, ReadyListEntry>& rhs) {
- return lhs.first < rhs.first;
- };
- std::priority_queue<std::pair<Priority, ReadyListEntry>,
- std::vector<std::pair<Priority, ReadyListEntry>>,
- decltype(priority_comparator)>
- ready_queue(priority_comparator);
+ // Use a multimap to sort ReadyListEntry according to their priority.
+ std::multimap<Priority, ReadyListEntry> ready_queue;
- // Set of instructions in the ready list.
- tensorflow::gtl::FlatSet<const HloInstruction*> ready_instructions;
+ // Map of ready instructions to their iterators in ready_queue.
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::multimap<Priority, ReadyListEntry>::iterator>
+ ready_instructions;
auto add_to_ready_queue = [&](HloInstruction* inst) {
auto entry = MakeReadyListEntry(inst);
- ready_queue.emplace(GetPriority(entry), std::move(entry));
- ready_instructions.insert(inst);
+ auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
+ ready_instructions[inst] = it;
};
for (auto* instruction : computation_.instructions()) {
@@ -247,14 +244,10 @@ class ListScheduler {
while (!ready_queue.empty()) {
// Remove the selected instruction from the ready list and add it to the
// schedule.
- const HloInstruction* best = ready_queue.top().second.instruction;
- ready_queue.pop();
- // We may have duplicates in the priority queue, because when a ready
- // instruction's priority goes up, we reinsert it to the priority queue.
- // Skip the duplicate.
- if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) {
- continue;
- }
+ auto best_it = ready_queue.end();
+ --best_it;
+ const HloInstruction* best = best_it->second.instruction;
+ ready_queue.erase(best_it);
ready_instructions.erase(best);
schedule.push_back(best);
scheduled_instructions_.insert(best);
@@ -287,16 +280,27 @@ class ListScheduler {
update_pred_count(succ);
}
// The unscheduled use count for a buffer has changed to 1, so the
- // priorities of some ready instructions may go up. We reinsert them to
- // the priority queue, so that they can appear earlier. The old entries
- // will become duplicates and will be skipped.
+ // priorities of some ready instructions may go up. We update them in the
+ // ready queue, so that they can appear earlier.
if (adjust_ready_queue) {
for (HloInstruction* operand : best->operands()) {
for (HloInstruction* operand_user : operand->users()) {
- if (ready_instructions.find(operand_user) !=
- ready_instructions.end()) {
- add_to_ready_queue(operand_user);
+ auto ready_instructions_it = ready_instructions.find(operand_user);
+ if (ready_instructions_it == ready_instructions.end()) {
+ continue;
+ }
+ auto ready_queue_it = ready_instructions_it->second;
+ auto& entry = ready_queue_it->second;
+ Priority new_priority = GetPriority(entry);
+ if (new_priority == ready_queue_it->first) {
+ continue;
}
+ // Create a new entry in ready_queue, then update
+ // ready_instructions[operand_user] to refer to the new entry.
+ ready_instructions_it->second =
+ ready_queue.emplace(new_priority, std::move(entry));
+ // Remove the old entry in ready_queue.
+ ready_queue.erase(ready_queue_it);
}
}
}
@@ -317,8 +321,9 @@ class ListScheduler {
buffer_uses_;
// A map containing the count of unscheduled HLOs that use a particular
- // LogicalBuffer. We rely on iterator stability in this map.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> unscheduled_use_count_;
+ // LogicalBuffer. We rely on pointer stability of the entries in this map
+ // (pointers to them are stored in ReadyListEntry), and on the map entries
+ // being std::pair's.
+ std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;