aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-09-25 13:04:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 13:08:20 -0700
commita50dff24b6f38fef7ead20e1015509cac905ed29 (patch)
treee0b2036b53bd289cf56acde6f99eb62e8a0afe57 /tensorflow
parent471e20a6738a326adeb0eef2d158b61bbfd23d6d (diff)
[XLA] Avoid recursion in global decreasing size best-fit heap.
PiperOrigin-RevId: 214489542
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc45
1 files changed, 23 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index a07eaaf997..2bd04259c0 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -827,33 +827,34 @@ class BufferIntervalTree {
// interval.
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
std::vector<Chunk> result;
- if (node_count_ > 0) {
- ChunksOverlappingInTimeHelper(start, end, &node_storage_[0], &result);
+ if (node_count_ == 0) {
+ return result;
+ }
+ std::vector<BufferIntervalTreeNode*> visiting_stack;
+ visiting_stack.push_back(&node_storage_[0]);
+ while (!visiting_stack.empty()) {
+ BufferIntervalTreeNode* top = visiting_stack.back();
+ visiting_stack.pop_back();
+ if (start > top->subtree_end) {
+ continue;
+ }
+ if (top->left != nullptr) {
+ visiting_stack.push_back(top->left);
+ }
+ if (top->start <= end && top->end >= start) {
+ result.push_back(top->chunk);
+ }
+ if (end < top->start) {
+ continue;
+ }
+ if (top->right != nullptr) {
+ visiting_stack.push_back(top->right);
+ }
}
return result;
}
private:
- void ChunksOverlappingInTimeHelper(int64 start, int64 end,
- BufferIntervalTreeNode* visiting_node,
- std::vector<Chunk>* result) {
- if (start > visiting_node->subtree_end) {
- return;
- }
- if (visiting_node->left != nullptr) {
- ChunksOverlappingInTimeHelper(start, end, visiting_node->left, result);
- }
- if (visiting_node->start <= end && visiting_node->end >= start) {
- result->push_back(visiting_node->chunk);
- }
- if (end < visiting_node->start) {
- return;
- }
- if (visiting_node->right != nullptr) {
- ChunksOverlappingInTimeHelper(start, end, visiting_node->right, result);
- }
- }
-
int64 node_count_ = 0;
std::vector<BufferIntervalTreeNode> node_storage_;
};