about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/hlo_memory_scheduler.cc 29
1 files changed, 21 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index bf30764488..5cee865b7a 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -195,13 +195,15 @@ class ListScheduler {
return entry;
}
- // Returns the number of bytes freed if the HLO instruction is scheduled.
- // If the instruction calls subcomputations, we count the memory used by the
- // subcomputations as memory "defined" by the instruction. This is not
- // entirely accurate, because subcomputation memory will be freed after the
- // instruction finishes. But it is more accurate than not taking
- // subcomputations into account at all. In the future, we may improve
- // accounting for subcomputation memory (b/65409243).
+ // Returns the number of bytes freed *after* the HLO instruction finishes.
+ // The current List algorithm only considers two states for an instruction:
+ // right before it runs, and after it finishes. We don't represent memory
+ // usage during the execution of an instruction. But if the instruction calls
+ // subcomputations, they are only live during the instruction's execution.
+ // We end up counting the memory used by subcomputations as memory "defined"
+ // by the instruction. This is not entirely accurate, but it is more accurate
+ // than not taking subcomputations into account at all. In the future, we may
+ // improve accounting for subcomputation memory (b/65409243).
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@@ -223,7 +225,18 @@ class ListScheduler {
}
}
}
- return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
+ int64 bytes_defined;
+ if (max_subcomputation_bytes > 0 &&
+ (entry.instruction->opcode() == HloOpcode::kWhile ||
+ entry.instruction->opcode() == HloOpcode::kCall ||
+ entry.instruction->opcode() == HloOpcode::kConditional)) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ bytes_defined = max_subcomputation_bytes;
+ } else {
+ bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
+ }
+ return freed_bytes - bytes_defined;
}
// Constructs the scheduling priority of the given instruction.