diff options
author | Dimitris Vardoulakis <dimvar@google.com> | 2018-10-04 13:43:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:53:22 -0700 |
commit | 08ecc62a38dc58e85cb46ad281486d1c75b1db9b (patch) | |
tree | bc620a28a23ff465d582b354dd46b5fc98a004f2 /tensorflow/compiler | |
parent | d96e073e77929006c519cd3082461d9757865dd7 (diff) |
[TF:XLA] Improve the accounting for subcomputations in the List scheduler to avoid double-counting.
PiperOrigin-RevId: 215795640
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_memory_scheduler.cc | 29 |
1 file changed, 21 insertions(+), 8 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index bf30764488..5cee865b7a 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -195,13 +195,15 @@ class ListScheduler { return entry; } - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). + // Returns the number of bytes freed *after* the HLO instruction finishes. + // The current List algorithm only considers two states for an instruction: + // right before it runs, and after it finishes. We don't represent memory + // usage during the execution of an instruction. But if the instruction calls + // subcomputations, they are only live during the instruction's execution. + // We end up counting the memory used by subcomputations as memory "defined" + // by the instruction. This is not entirely accurate, but it is more accurate + // than not taking subcomputations into account at all. In the future, we may + // improve accounting for subcomputation memory (b/65409243). 
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -223,7 +225,18 @@ class ListScheduler { } } } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + int64 bytes_defined; + if (max_subcomputation_bytes > 0 && + (entry.instruction->opcode() == HloOpcode::kWhile || + entry.instruction->opcode() == HloOpcode::kCall || + entry.instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + bytes_defined = max_subcomputation_bytes; + } else { + bytes_defined = entry.bytes_defined + max_subcomputation_bytes; + } + return freed_bytes - bytes_defined; } // Constructs the scheduling priority of the given instruction. |