author    Dimitris Vardoulakis <dimvar@google.com>  2018-10-04 13:43:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-04 13:53:22 -0700
commit    08ecc62a38dc58e85cb46ad281486d1c75b1db9b (patch)
tree      bc620a28a23ff465d582b354dd46b5fc98a004f2 /tensorflow/compiler
parent    d96e073e77929006c519cd3082461d9757865dd7 (diff)
[TF:XLA] Improve the accounting for subcomputations in the List scheduler to avoid double-counting.
PiperOrigin-RevId: 215795640
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.cc | 29
1 file changed, 21 insertions(+), 8 deletions(-)
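
As a concrete illustration of the double-counting (the byte counts here are hypothetical, not taken from the patch): consider a kWhile whose body peaks at 512 bytes and whose own 128-byte output is aliased with the body root's output. The old formula charged entry.bytes_defined + max_subcomputation_bytes = 128 + 512 = 640 bytes against the instruction, counting the shared output buffer twice; with the patch below, only the 512-byte subcomputation peak is charged.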
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index bf30764488..5cee865b7a 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -195,13 +195,15 @@ class ListScheduler {
     return entry;
   }
 
-  // Returns the number of bytes freed if the HLO instruction is scheduled.
-  // If the instruction calls subcomputations, we count the memory used by the
-  // subcomputations as memory "defined" by the instruction. This is not
-  // entirely accurate, because subcomputation memory will be freed after the
-  // instruction finishes. But it is more accurate than not taking
-  // subcomputations into account at all. In the future, we may improve
-  // accounting for subcomputation memory (b/65409243).
+  // Returns the number of bytes freed *after* the HLO instruction finishes.
+  // The current List algorithm only considers two states for an instruction:
+  // right before it runs, and after it finishes. We don't represent memory
+  // usage during the execution of an instruction. But if the instruction calls
+  // subcomputations, they are only live during the instruction's execution.
+  // We end up counting the memory used by subcomputations as memory "defined"
+  // by the instruction. This is not entirely accurate, but it is more accurate
+  // than not taking subcomputations into account at all. In the future, we may
+  // improve accounting for subcomputation memory (b/65409243).
   int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
     int64 freed_bytes = 0;
     for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@@ -223,7 +225,18 @@ class ListScheduler {
         }
       }
     }
-    return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
+    int64 bytes_defined;
+    if (max_subcomputation_bytes > 0 &&
+        (entry.instruction->opcode() == HloOpcode::kWhile ||
+         entry.instruction->opcode() == HloOpcode::kCall ||
+         entry.instruction->opcode() == HloOpcode::kConditional)) {
+      // The output buffer of while/call/conditional is always aliased with the
+      // output buffer of the root instruction in the body. Don't double count.
+      bytes_defined = max_subcomputation_bytes;
+    } else {
+      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
+    }
+    return freed_bytes - bytes_defined;
   }
 
   // Constructs the scheduling priority of the given instruction.
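
A minimal standalone sketch of the new accounting, using hypothetical stand-in types and byte counts (the real scheduler derives these quantities from ReadyListEntry, the instruction's opcode, and its per-subcomputation memory estimates):

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the scheduler state consulted by the real
// ListScheduler::BytesFreedIfScheduled; field names are illustrative only.
struct EntryStats {
  int64_t freed_bytes;                // operand bytes freed once this op runs
  int64_t bytes_defined;              // bytes of the instruction's own output
  int64_t max_subcomputation_bytes;   // peak memory of any called computation
  bool is_while_call_or_conditional;  // opcode is kWhile, kCall, or kConditional
};

// Mirrors the patched accounting: for while/call/conditional the output buffer
// is aliased with the called computation's root output, so charging both
// bytes_defined and max_subcomputation_bytes would count that buffer twice.
int64_t BytesFreedIfScheduled(const EntryStats& e) {
  int64_t bytes_defined;
  if (e.max_subcomputation_bytes > 0 && e.is_while_call_or_conditional) {
    bytes_defined = e.max_subcomputation_bytes;
  } else {
    bytes_defined = e.bytes_defined + e.max_subcomputation_bytes;
  }
  return e.freed_bytes - bytes_defined;
}

int main() {
  // Same hypothetical kWhile as the example above: 128-byte output aliased
  // with the body root, body peak of 512 bytes, 256 bytes of operands freed.
  EntryStats while_entry{256, 128, 512, true};
  std::cout << BytesFreedIfScheduled(while_entry) << "\n";  // 256 - 512 = -256
}

The sketch only restates the arithmetic of the patch; in the real code, freed_bytes and max_subcomputation_bytes are computed from the scheduler's buffer-use bookkeeping rather than passed in directly.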