diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.cc | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 44293f582e..b1ee2e46b0 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1160,28 +1160,25 @@ StatusOr<bool> HloRematerialization::Run( TuplePointsToAnalysis::Run( module, /*include_loop_fusion_instructions=*/true)); - // Adjust memory limit to account for the parameter and output of the entry + // Adjust memory limit to account for the output of the entry // computation. This is necessary because the per-computation accounting in - // MemoryUsageTracker do not include parameters and output as these are - // typically allocated by the caller. With this adjustment the memory limit - // accounts for the size of all HLO instructions (parameters, output - // instructions, etc). - auto total_size = [this](const HloInstruction* instruction) { - int64 total_size = 0; - for (const LogicalBuffer* logical_buffer : - points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) { - total_size += size_function_(logical_buffer->shape()); - } - return total_size; - }; - const HloComputation* entry_computation = module->entry_computation(); - memory_limit_bytes -= total_size(entry_computation->root_instruction()); - for (const HloInstruction* param : - entry_computation->parameter_instructions()) { - memory_limit_bytes -= total_size(param); - } - VLOG(1) << "Adjusted memory limit accounting for parameters and output: " - << HumanReadableNumBytes(memory_limit_bytes); + // MemoryUsageTracker do not include output as these are typically allocated + // by the caller. + int64 module_output_size = 0; + ShapeUtil::ForEachSubshape( + module->entry_computation()->root_instruction()->shape(), + [&module_output_size, this](const Shape& subshape, + const ShapeIndex& /*index*/) { + module_output_size += size_function_(subshape); + return Status::OK(); + }) + .IgnoreError(); + + const int64 adjusted_memory_limit_bytes = + memory_limit_bytes - module_output_size; + VLOG(1) << "Adjusted memory limit accounting for output (" + << HumanReadableNumBytes(module_output_size) + << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. @@ -1204,8 +1201,13 @@ StatusOr<bool> HloRematerialization::Run( return Status::OK(); })); + // The peak memory usage of the module equals the peak memory use of the entry + // computation plus the output size of the computation. This is because the + // peak memory for a computation does not include the output as this is + // typically accounted for in the caller. const int64 before_peak_memory = - computation_peak_memory_.at(module->entry_computation()); + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; VLOG(1) << "Peak memory usage of module (before): " << HumanReadableNumBytes(before_peak_memory); @@ -1216,9 +1218,9 @@ StatusOr<bool> HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, - RematerializeComputation(module->entry_computation(), - sequence, memory_limit_bytes)); + TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( + module->entry_computation(), sequence, + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1257,7 +1259,8 @@ StatusOr<bool> HloRematerialization::Run( << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; const int64 current_peak_memory = - computation_peak_memory_.at(module->entry_computation()); + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; VLOG(1) << "Peak memory usage of module now " << HumanReadableNumBytes(current_peak_memory) << " (" << current_peak_memory << " bytes), was " |