From 08ecc62a38dc58e85cb46ad281486d1c75b1db9b Mon Sep 17 00:00:00 2001
From: Dimitris Vardoulakis
Date: Thu, 4 Oct 2018 13:43:31 -0700
Subject: [TF:XLA] Improve the accounting for subcomputations in the List
 scheduler to avoid double-counting.

PiperOrigin-RevId: 215795640
---
 .../compiler/xla/service/hlo_memory_scheduler.cc   | 29 ++++++++++++++++------
 1 file changed, 21 insertions(+), 8 deletions(-)

(limited to 'tensorflow/compiler')

diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index bf30764488..5cee865b7a 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -195,13 +195,15 @@ class ListScheduler {
     return entry;
   }
 
-  // Returns the number of bytes freed if the HLO instruction is scheduled.
-  // If the instruction calls subcomputations, we count the memory used by the
-  // subcomputations as memory "defined" by the instruction. This is not
-  // entirely accurate, because subcomputation memory will be freed after the
-  // instruction finishes. But it is more accurate than not taking
-  // subcomputations into account at all. In the future, we may improve
-  // accounting for subcomputation memory (b/65409243).
+  // Returns the number of bytes freed *after* the HLO instruction finishes.
+  // The current List algorithm only considers two states for an instruction:
+  // right before it runs, and after it finishes. We don't represent memory
+  // usage during the execution of an instruction. But if the instruction calls
+  // subcomputations, they are only live during the instruction's execution.
+  // We end up counting the memory used by subcomputations as memory "defined"
+  // by the instruction. This is not entirely accurate, but it is more accurate
+  // than not taking subcomputations into account at all. In the future, we may
+  // improve accounting for subcomputation memory (b/65409243).
   int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
     int64 freed_bytes = 0;
     for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@@ -223,7 +225,18 @@ class ListScheduler {
         }
       }
     }
-    return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
+    int64 bytes_defined;
+    if (max_subcomputation_bytes > 0 &&
+        (entry.instruction->opcode() == HloOpcode::kWhile ||
+         entry.instruction->opcode() == HloOpcode::kCall ||
+         entry.instruction->opcode() == HloOpcode::kConditional)) {
+      // The output buffer of while/call/conditional is always aliased with the
+      // output buffer of the root instruction in the body. Don't double count.
+      bytes_defined = max_subcomputation_bytes;
+    } else {
+      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
+    }
+    return freed_bytes - bytes_defined;
   }
 
   // Constructs the scheduling priority of the given instruction.
-- 
cgit v1.2.3
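
For readers skimming the diff, the following is a minimal standalone sketch of the accounting rule the patch introduces, not the actual XLA code: the Entry struct, the Opcode enum, and the free function below are simplified stand-ins for ReadyListEntry, HloOpcode, and the real ListScheduler member, and the buffer bookkeeping is reduced to plain integers that are assumed to have been computed elsewhere.

#include <cstdint>

// Hypothetical, simplified stand-ins for the real XLA types.
enum class Opcode { kWhile, kCall, kConditional, kOther };

struct Entry {
  Opcode opcode;                     // opcode of the ready instruction
  int64_t freed_bytes;               // bytes whose last unscheduled use is this instruction
  int64_t bytes_defined;             // bytes of the buffers this instruction defines
  int64_t max_subcomputation_bytes;  // peak memory of the largest called subcomputation
};

// Mirrors the patched BytesFreedIfScheduled: memory used by called
// subcomputations is charged to the caller as if it were "defined" by it, but
// for while/call/conditional the caller's output buffer is aliased with the
// output of the called body's root, so adding bytes_defined on top of the
// subcomputation peak would count that buffer twice.
int64_t BytesFreedIfScheduled(const Entry& entry) {
  int64_t bytes_defined;
  if (entry.max_subcomputation_bytes > 0 &&
      (entry.opcode == Opcode::kWhile || entry.opcode == Opcode::kCall ||
       entry.opcode == Opcode::kConditional)) {
    bytes_defined = entry.max_subcomputation_bytes;
  } else {
    bytes_defined = entry.bytes_defined + entry.max_subcomputation_bytes;
  }
  return entry.freed_bytes - bytes_defined;
}

The net effect is that a while/call/conditional whose body needs more memory than the instruction's own outputs is no longer penalized for both, which makes the List scheduler's "bytes freed if scheduled" priority less pessimistic for such instructions.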