-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 55
1 file changed, 29 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 44293f582e..b1ee2e46b0 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1160,28 +1160,25 @@ StatusOr<bool> HloRematerialization::Run(
TuplePointsToAnalysis::Run(
module, /*include_loop_fusion_instructions=*/true));
- // Adjust memory limit to account for the parameter and output of the entry
+ // Adjust memory limit to account for the output of the entry
// computation. This is necessary because the per-computation accounting in
- // MemoryUsageTracker do not include parameters and output as these are
- // typically allocated by the caller. With this adjustment the memory limit
- // accounts for the size of all HLO instructions (parameters, output
- // instructions, etc).
- auto total_size = [this](const HloInstruction* instruction) {
- int64 total_size = 0;
- for (const LogicalBuffer* logical_buffer :
- points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
- total_size += size_function_(logical_buffer->shape());
- }
- return total_size;
- };
- const HloComputation* entry_computation = module->entry_computation();
- memory_limit_bytes -= total_size(entry_computation->root_instruction());
- for (const HloInstruction* param :
- entry_computation->parameter_instructions()) {
- memory_limit_bytes -= total_size(param);
- }
- VLOG(1) << "Adjusted memory limit accounting for parameters and output: "
- << HumanReadableNumBytes(memory_limit_bytes);
+ // MemoryUsageTracker does not include the output, as it is typically
+ // allocated by the caller.
+ int64 module_output_size = 0;
+ ShapeUtil::ForEachSubshape(
+ module->entry_computation()->root_instruction()->shape(),
+ [&module_output_size, this](const Shape& subshape,
+ const ShapeIndex& /*index*/) {
+ module_output_size += size_function_(subshape);
+ return Status::OK();
+ })
+ .IgnoreError();
+
+ const int64 adjusted_memory_limit_bytes =
+ memory_limit_bytes - module_output_size;
+ VLOG(1) << "Adjusted memory limit accounting for output ("
+ << HumanReadableNumBytes(module_output_size)
+ << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
// Create initial sequence of HLO instructions.
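
For reference, the subshape walk added in the hunk above can be read as a small helper. The following is only an illustrative sketch against this revision of the XLA API: the helper name ModuleOutputBytes and the size_function parameter are invented here, and the Status-returning visitor plus IgnoreError() simply mirrors the ForEachSubshape call shown in the diff.

#include <functional>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Illustrative only: sums `size_function` over every subshape of the entry
// computation's root instruction. This is the amount the pass now subtracts
// from its memory limit, since the module output is allocated by the caller.
int64 ModuleOutputBytes(
    const HloModule& module,
    const std::function<int64(const Shape&)>& size_function) {
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module.entry_computation()->root_instruction()->shape(),
      [&](const Shape& subshape, const ShapeIndex& /*index*/) {
        module_output_size += size_function(subshape);
        return Status::OK();
      })
      .IgnoreError();
  return module_output_size;
}

}  // namespace xla
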
@@ -1204,8 +1201,13 @@ StatusOr<bool> HloRematerialization::Run(
return Status::OK();
}));
+ // The peak memory usage of the module equals the peak memory use of the entry
+ // computation plus the output size of the computation. This is because the
+ // peak memory for a computation does not include the output as this is
+ // typically accounted for in the caller.
const int64 before_peak_memory =
- computation_peak_memory_.at(module->entry_computation());
+ computation_peak_memory_.at(module->entry_computation()) +
+ module_output_size;
VLOG(1) << "Peak memory usage of module (before): "
<< HumanReadableNumBytes(before_peak_memory);
@@ -1216,9 +1218,9 @@ StatusOr<bool> HloRematerialization::Run(
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed,
- RematerializeComputation(module->entry_computation(),
- sequence, memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
+ module->entry_computation(), sequence,
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1257,7 +1259,8 @@ StatusOr<bool> HloRematerialization::Run(
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
const int64 current_peak_memory =
- computation_peak_memory_.at(module->entry_computation());
+ computation_peak_memory_.at(module->entry_computation()) +
+ module_output_size;
VLOG(1) << "Peak memory usage of module now "
<< HumanReadableNumBytes(current_peak_memory) << " ("
<< current_peak_memory << " bytes), was "
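
Taken together, the accounting change in this diff is simple arithmetic: the caller-allocated module output is carved out of the limit handed to RematerializeComputation, and added back whenever a module-level peak is reported. The sketch below is standalone and uses made-up sizes (the 1 GiB limit, 64 MiB output, and 900 MiB computation peak are hypothetical, not taken from the change):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical figures for illustration only.
  const int64_t memory_limit_bytes = int64_t{1} << 30;    // 1 GiB budget
  const int64_t module_output_size = int64_t{64} << 20;   // 64 MiB entry output

  // First hunk: the rematerializer runs against the reduced limit, since the
  // output lives in the caller's allocation.
  const int64_t adjusted_memory_limit_bytes =
      memory_limit_bytes - module_output_size;

  // Later hunks: per-computation peaks exclude the output, so it is added
  // back when the module-level peak is logged.
  const int64_t computation_peak_memory = int64_t{900} << 20;  // from the tracker
  const int64_t module_peak_memory =
      computation_peak_memory + module_output_size;

  std::printf("adjusted limit: %lld bytes, module peak: %lld bytes\n",
              static_cast<long long>(adjusted_memory_limit_bytes),
              static_cast<long long>(module_peak_memory));
  return 0;
}
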