author:    2017-06-16 16:18:44 -0700
committer: 2017-06-16 16:24:15 -0700
commit:    2473505b70f03838fef4ae6108387ba66d443d62
tree:      5e1bd767e5e5cc1d8cf2e1d329e8bd7108b29c86
parent:    df609c9de4ea0cae0fb1d41893b8071d67bd6bb2
Adjust rematerialization cost function to only be the inverse of its benefit (total_memory/saved_memory).
PiperOrigin-RevId: 159290105
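
For illustration, a minimal standalone C++ sketch of the new heuristic follows. The Candidate struct, the sample op names, and the 1 GiB limit are hypothetical stand-ins; in the patch itself the inputs come from HloInstruction, MemoryUsageTracker, and the pass's memory_limit_bytes.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for an HLO instruction; the real pass queries
    // MemoryUsageTracker for the bytes a rematerialization would free.
    struct Candidate {
      const char* name;
      int64_t memory_reduced;  // bytes freed by rematerializing this op
    };

    // New heuristic from the patch: cost is the inverse of the benefit,
    // memory_limit_bytes / memory_reduced. Compute cost is ignored.
    int64_t RematerializationCost(int64_t memory_reduced,
                                  int64_t memory_limit_bytes) {
      return memory_limit_bytes / memory_reduced;  // caller ensures > 0
    }

    int main() {
      const int64_t memory_limit_bytes = int64_t{1} << 30;  // 1 GiB, hypothetical
      std::vector<Candidate> candidates = {
          {"dot", 64 << 20}, {"add", 4 << 20}, {"broadcast", 256 << 20}};

      // Mirror the selection loop: keep the lowest-cost candidate, i.e. the
      // one that frees the largest share of the memory limit.
      const Candidate* best = nullptr;
      int64_t best_cost = 0;
      for (const Candidate& c : candidates) {
        if (c.memory_reduced <= 0) continue;  // no benefit, skip
        const int64_t cost =
            RematerializationCost(c.memory_reduced, memory_limit_bytes);
        if (best == nullptr || cost < best_cost) {
          best = &c;
          best_cost = cost;
        }
      }
      std::cout << "best: " << best->name << " cost " << best_cost << "\n";
      // Prints "best: broadcast cost 4", since 2^30 / 2^28 = 4.
    }

Note that the ranking depends only on memory_reduced; memory_limit_bytes is a constant scale factor that keeps the integer cost meaningfully above zero for realistic savings.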
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h  |  6
2 files changed, 15 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 2c1b0fff4e..fb6d8674b6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -58,9 +58,8 @@ bool IsRematerializable(const HloInstruction* instruction) {
     return false;
   }
 
-  // Don't rematerialize instructions with side effects, those with a cost that
-  // might not be captured by HloCostAnalysis, or instructions which cannot be
-  // cloned safely.
+  // Don't rematerialize instructions with side effects or instructions which
+  // cannot be cloned safely.
   switch (instruction->opcode()) {
     case HloOpcode::kCall:
     case HloOpcode::kConstant:
@@ -802,23 +801,14 @@ bool MemoryUsageTracker::Check() const {
 // Computes and returns the cost of rematerializing the given instruction.
 // Cost per rematerialized instruction is defined as:
 //
-// (flop_count + transcendental_count + element_count) / memory_reduced
+// memory_limit_bytes / memory_reduced
 //
-// flop_count: from HloCostAnalysis
-// transcendental_count: from HloCostAnalysis
-// element_count: number of elements accessed in operands and output of
-// instruction
-// memory_reduced: The memory usage reduced by rematerializing the
-// instruction.
-//
-// This is a rough estimate of the extra execution time per byte saved by
-// rematerializing this instruction for its remaining uses. In general, we
-// want the most memory saving for the least latency penalty which is captured
-// by this heuristic.
+// The idea is to choose the operation that will save the most memory for
+// rematerialization and do not worry about how much the compute costs since
+// running out of memory is more harmful than taking longer to get the answer.
 int64 RematerializationCost(const HloInstruction* instruction,
                             const MemoryUsageTracker& memory_tracker,
-                            const HloCostAnalysis& cost_analysis,
-                            int64 memory_reduced) {
+                            int64 memory_reduced, int64 memory_limit_bytes) {
   // If none of the users of 'instruction' have been placed in the sequence (as
   // tracked by memory_tracker), then rematerialization of 'instruction' is a
   // zero-cost move of 'instruction' in the sequence.
@@ -830,22 +820,8 @@ int64 RematerializationCost(const HloInstruction* instruction,
   }
 
   CHECK_GT(memory_reduced, 0);
-  const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction);
-  const int64 elements_accessed =
-      ShapeUtil::IsTuple(instruction->shape())
-          ? bytes_accessed
-          : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType(
-                                 instruction->shape().element_type());
-
-  // Multiply by 256 to improve precision of cost. Without this factor,
-  // many instructions such as many elementwise instructions would have
-  // zero cost because the bytes reduced can be several times greater than
-  // the element count.
-  return 256 *
-         (cost_analysis.flop_count(*instruction) +
-          cost_analysis.transcendental_count(*instruction) +
-          elements_accessed) /
-         memory_reduced;
+  // Return the inverse of the benefit of rematerialization.
+  return memory_limit_bytes / memory_reduced;
 }
 
 // Selects and returns the best candidate instruction for rematerialization.
@@ -856,8 +832,8 @@ int64 RematerializationCost(const HloInstruction* instruction,
 HloInstruction* PickRematerializationCandidate(
     const MemoryUsageTracker& memory_tracker,
     const InstructionList& instruction_list,
-    const HloCostAnalysis& cost_analysis,
-    const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist) {
+    const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist,
+    int64 memory_limit_bytes) {
   HloInstruction* best = nullptr;
   int64 best_cost = 0;
@@ -891,12 +867,12 @@ HloInstruction* PickRematerializationCandidate(
     if (memory_reduced <= 0) {
       VLOG(5) << "candidate " << candidate->name()
-              << " memory reduced = " << memory_reduced << " <= 0";
+              << " memory reduced = " << memory_reduced << " <= 0";
       continue;
     }
 
     const int cost = RematerializationCost(candidate, memory_tracker,
-                                           cost_analysis, memory_reduced);
+                                           memory_reduced, memory_limit_bytes);
 
     VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
             << memory_reduced << ", cost per byte " << cost;
@@ -1011,7 +987,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
             << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
 
     HloInstruction* best = PickRematerializationCandidate(
-        memory_tracker, instruction_list, cost_analysis_, blacklist);
+        memory_tracker, instruction_list, blacklist, memory_limit_bytes);
 
     if (best == nullptr) {
       VLOG(3) << "Unable to find rematerialization candidate at program "
@@ -1211,11 +1187,6 @@ StatusOr<bool> HloRematerialization::Run(
   VLOG(1) << "Peak memory usage of module (before): "
           << HumanReadableNumBytes(before_peak_memory);
 
-  // Run cost analysis. Operation cost is used in the heuristic for selecting
-  // instructions for rematerialization.
-  TF_RETURN_IF_ERROR(
-      module->entry_computation()->root_instruction()->Accept(&cost_analysis_));
-
   // Subcomputations called by the entry computation will also be
   // rematerialized.
   TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 1693f93183..42c279d440 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -18,7 +18,6 @@
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -61,7 +60,7 @@ class HloRematerialization {
  protected:
   HloRematerialization(const ShapeSizeFunction& size_function)
-      : size_function_(size_function), cost_analysis_(size_function_) {}
+      : size_function_(size_function) {}
   ~HloRematerialization() {}
 
   // Runs rematerialization on the given module. Returns whether the module was
@@ -100,9 +99,6 @@ class HloRematerialization {
   // Call graph of the hlo_module.
   std::unique_ptr<CallGraph> call_graph_;
 
-  // Analysis used for computing the rematerialization cost of instructions.
-  HloCostAnalysis cost_analysis_;
-
   // The peak memory usage of each computation. The map contains only those
   // computations called from sequential context
   // (CallContext::kSequential). These values are updated as rematerialization
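
A worked example of the integer arithmetic in the new formula, under the same hypothetical 1 GiB limit as the sketch above:

    cost(free 400 MiB) = 2^30 / (400 * 2^20) = 2   (integer division)
    cost(free 500 MiB) = 2^30 / (500 * 2^20) = 2   (ties with the 400 MiB candidate)

Because the division is integral, candidates with similar savings can tie; using memory_limit_bytes rather than a small constant as the numerator keeps the ranking reasonably fine-grained. The division is never reached for non-positive savings: per the retained comment, an instruction none of whose users have been placed is a zero-cost move, and CHECK_GT(memory_reduced, 0) guards the remaining path.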