diff options
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 126 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.h | 29 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator.cc | 90 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator.h | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator_test.cc | 126 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.cc | 62 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering_test.cc | 59 |
8 files changed, 441 insertions, 114 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 750e1ee3f2..2452158efa 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -666,8 +666,8 @@ cc_library( ], deps = [ ":buffer_liveness", - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -707,51 +707,38 @@ cc_test( ], ) -cc_library( - name = "heap_simulator", - srcs = [ - "heap_simulator.cc", - ], - hdrs = [ - "heap_simulator.h", - ], - deps = [ - ":hlo", - ":liveness_util", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) +# The hlo_ordering library contains both hlo_ordering and heap_simulator because +# they are mutually dependent. cc_library( name = "hlo_ordering", srcs = [ + "heap_simulator.cc", "hlo_ordering.cc", ], hdrs = [ + "heap_simulator.h", "hlo_ordering.h", ], deps = [ ":call_graph", - ":heap_simulator", ":hlo", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3cdbf892f7..a79468f939 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -548,6 +548,8 @@ Status BufferAssigner::AssignBuffersForComputation( const FlatSet<const HloInstruction*>* hlos_to_allocate, const FlatSet<const LogicalBuffer*>& colocated_buffers, const FlatSet<BufferAllocation::Index>& colocated_allocations, + FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>* + buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. @@ -578,9 +580,16 @@ Status BufferAssigner::AssignBuffersForComputation( // If there is a sequential instruction ordering, we'll delay assignment of // temp buffers until after the main assignment loop. const BufferLiveness& liveness = assignment->liveness(); - const std::vector<const HloInstruction*>* sequential_order = - liveness.hlo_ordering().SequentialOrder(*computation); - FlatSet<const LogicalBuffer*> unassigned_temp_buffers; + const bool has_sequential_order = + liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; + if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { + // Every sequential computation must get an entry in the + // buffers_to_assign_sequentially map, even if we end up with an empty set + // of buffers. This ensures we can correctly determine whether to run + // whole-module heap simulation. + buffers_to_assign_sequentially->emplace(computation, + FlatSet<const LogicalBuffer*>()); + } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers // first for simplicity. This means any previously created BufferAllocation is @@ -599,7 +608,7 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, sequential_order, &liveness, &post_order_position]( + [this, has_sequential_order, &liveness, &post_order_position]( const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = buffer_size_(*a); @@ -609,7 +618,7 @@ Status BufferAssigner::AssignBuffersForComputation( } // Otherwise live out buffers come before others, if the // instructions are sequentially ordered. - if (sequential_order != nullptr) { + if (has_sequential_order) { const bool a_live_out = liveness.MaybeLiveOut(*a); const bool b_live_out = liveness.MaybeLiveOut(*b); if (a_live_out != b_live_out) { @@ -746,7 +755,7 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!assignment->HasAllocation(*buffer) && sequential_order != nullptr && + if (!assignment->HasAllocation(*buffer) && has_sequential_order && !liveness.MaybeLiveOut(*buffer)) { // There is a sequential instruction ordering, so we delay assignment of // temp buffers until after the loop. We do this right before we decide to @@ -758,7 +767,7 @@ Status BufferAssigner::AssignBuffersForComputation( // for the definition of temp buffers. CHECK(!is_entry_parameter) << *buffer; CHECK(!is_thread_local) << *buffer; - unassigned_temp_buffers.insert(buffer); + (*buffers_to_assign_sequentially)[computation].insert(buffer); VLOG(3) << "Delaying assignment of temp buffer: " << *buffer; continue; } @@ -772,27 +781,68 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!unassigned_temp_buffers.empty()) { - TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( - *sequential_order, unassigned_temp_buffers, *computation, assignment)); - } return Status::OK(); } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const std::vector<const HloInstruction*>& sequence, - const FlatSet<const LogicalBuffer*>& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment) { + const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>( - MakeUnique<LazyBestFitHeap>(alignment_)), - sequence, computation, - assignment->points_to_analysis(), buffer_size_, - &buffers_to_assign)); + const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + if (run_whole_module_heap_simulation) { + // Run the heap simulation over the whole module. This reduces memory usage, + // since buffers for kCall and kWhile sub-computations are only live for the + // duration of their calling instructions. + VLOG(1) << "Running whole-module heap simulation"; + SequentialHloOrdering::HloModuleSequence module_sequence; + FlatSet<const LogicalBuffer*> all_buffers_to_assign; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const std::vector<const HloInstruction*>* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + module_sequence[computation] = *instruction_sequence; + all_buffers_to_assign.insert(buffers_to_assign.begin(), + buffers_to_assign.end()); + } + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>( + MakeUnique<LazyBestFitHeap>(alignment_)), + assignment->module(), module_sequence, + assignment->points_to_analysis(), buffer_size_, + &all_buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } else { + // Run the heap-simulation on a per-computation basis. Buffers for + // sub-computations are assigned disjoint BufferAllocations, assuming the + // worst-case that they may all be live concurrently. + VLOG(1) << "Running per-computation heap simulation"; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const std::vector<const HloInstruction*>* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>( + MakeUnique<LazyBestFitHeap>(alignment_)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), buffer_size_, + &buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } + } + return Status::OK(); +} + +void BufferAssigner::AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, BufferAssignment* assignment) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = result.fragmentation_size; @@ -801,8 +851,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( result.fragmentation_size; } - // Use the results of the heap simulator to create one allocation per - // computation, with LogicalBuffers packed to specific offsets. BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true); for (const auto& buffer_chunk : result.chunk_map) { @@ -810,7 +858,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } - return Status::OK(); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -1108,8 +1155,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness, BufferLiveness::Run(module, std::move(hlo_ordering))); - std::vector<const HloComputation*> thread_local_computations; - std::vector<const HloComputation*> global_computations; VLOG(1) << "Assigning buffers to module " << module->name(); if (hlos_to_allocate != nullptr) { VLOG(3) << "LogicalBuffer assignment restricted to hlos: "; @@ -1121,9 +1166,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( - module, &thread_local_computations, &global_computations)); - // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to // AssignBuffersForComputation for fast membership testing. std::unique_ptr<FlatSet<const HloInstruction*>> hlo_set; @@ -1148,16 +1190,38 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), &colocated_buffers, &colocated_allocations); + std::vector<const HloComputation*> thread_local_computations; + std::vector<const HloComputation*> global_computations; + TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( + module, &thread_local_computations, &global_computations)); + + // First assign buffers for global computatations. Temporary buffers for + // sequential computations are collected in 'buffers_to_assign_sequentially'. + FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>> + buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( computation, /*is_thread_local=*/false, hlo_set.get(), - colocated_buffers, colocated_allocations, assignment.get())); + colocated_buffers, colocated_allocations, + &buffers_to_assign_sequentially, assignment.get())); } + // Assign buffers with sequential ordering, if any. If all global computations + // are sequential, we can run heap simuation on the whole module, which + // reduces memory usage. + const bool run_whole_module_heap_simulation = + buffers_to_assign_sequentially.size() == global_computations.size(); + TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( + buffers_to_assign_sequentially, run_whole_module_heap_simulation, + assignment.get())); + + // Now assign buffers for thread-local computations. All LogicalBuffers get + // their own BufferAllocation. for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); TF_RETURN_IF_ERROR(AssignBuffersForComputation( computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers, - colocated_allocations, assignment.get())); + colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, + assignment.get())); } // Mark all buffers which may be live out of the entry computation as diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 34667c435d..7e96caf0f4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -354,6 +355,9 @@ class BufferAssignment { void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size); + // Returns the HloModule used to construct this assignment. + const HloModule& module() { return *module_; } + // Returns the BufferLiveness object used to construct this assignment. const BufferLiveness& liveness() { return *liveness_; } @@ -427,14 +431,27 @@ class BufferAssigner { const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers, const tensorflow::gtl::FlatSet<BufferAllocation::Index>& colocated_allocations, + tensorflow::gtl::FlatMap<const HloComputation*, + tensorflow::gtl::FlatSet<const LogicalBuffer*>>* + buffers_to_assign_sequentially, BufferAssignment* assignment); - // Assigns 'buffers_to_assign' assuming the HLO instructions will be executed - // in the given 'sequential_order'. + // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming + // the HLO instructions will be executed in the sequential order given by + // assignment->liveness().hlo_ordering().SequentialOrder. If + // 'run_whole_module_heap_simulation' is true, the heap simulation will be run + // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( - const std::vector<const HloInstruction*>& sequential_order, - const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment); + const tensorflow::gtl::FlatMap< + const HloComputation*, + tensorflow::gtl::FlatSet<const LogicalBuffer*>>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment); + + // Uses the results of the heap simulator to create a single allocation, with + // LogicalBuffers packed to specific offsets. + void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, + BufferAssignment* assignment); // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. @@ -477,8 +494,6 @@ class BufferAssigner { const HloComputation& computation, const BufferLiveness& buffer_liveness, std::vector<ColocatedBufferSet>* colocated_buffer_sets); - const HloModule* module_; - // Function which returns the buffer size for a given logical buffer (shape). LogicalBuffer::SizeFunction buffer_size_; diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9c4899a67d..d7aa5664df 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -53,12 +53,44 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers( /*static*/ StatusOr<HeapSimulator::Result> HeapSimulator::Run( - std::unique_ptr<HeapAlgorithm> algorithm, + std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet<const LogicalBuffer*>* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + const HloComputation* entry_computation = module.entry_computation(); + const std::vector<const HloInstruction*>& instruction_sequence = + FindOrDie(module_sequence, entry_computation); + TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation, + instruction_sequence, + points_to_analysis, &module_sequence)); + return heap.Finish(); +} + +/*static*/ +StatusOr<HeapSimulator::Result> HeapSimulator::Run( + std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation, const std::vector<const HloInstruction*>& instruction_sequence, - const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const FlatSet<const LogicalBuffer*>* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + points_to_analysis, + /*module_sequence=*/nullptr)); + return heap.Finish(); +} + +// Runs a heap simulation for the given 'computation', assuming the given +// 'instruction_sequence'. If 'module_sequence' is non-null, it is used to find +// kCall and kWhile sub-computations, and the heap simulation for those +// sub-computations will be run recursively. +Status HeapSimulator::RunComputation( + const HloComputation& computation, + const std::vector<const HloInstruction*>& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence) { // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -67,7 +99,6 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers; const HloInstruction* root = computation.root_instruction(); @@ -90,7 +121,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( // lifetime of buffers that aren't already connected by a data dependency. std::vector<const LogicalBuffer*> dead_buffers_to_free; for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } for (const BufferAlias& alias : @@ -127,7 +158,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( std::vector<const LogicalBuffer*> operand_buffers_to_free; for (const LogicalBuffer* operand_buffer : UniqueOperandSourceBuffers(instruction, points_to_analysis)) { - if (heap.IgnoreBuffer(operand_buffer)) { + if (IgnoreBuffer(operand_buffer)) { continue; } live_buffers[operand_buffer].erase(instruction); @@ -142,10 +173,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( // happen before dead or operand buffers are freed; the instruction reads // the operand buffers to produce its output. // - // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each - // buffer that we should assign. + // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer + // that we should assign. for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } @@ -159,24 +190,50 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { - heap.ShareBuffer(buffer, operand_buffer); + ShareBuffer(buffer, operand_buffer); shared = true; break; } } if (!shared) { - heap.Alloc(buffer); + Alloc(buffer); } } + // If the whole module is sequential, we can save memory by running the + // heap-simulation for sub-computations inline. E.g. the buffers for the + // condition and body of a kWhile instruction are only live for the duration + // of the instruction itself. + // + // The order that the sub-computations are simulated does not affect + // correctness; since the whole module is sequential, we know that the + // sub-computations will never be run concurrently. + if (module_sequence != nullptr) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kWhile) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + const std::vector<const HloInstruction*>& called_sequence = + FindOrDie(*module_sequence, called_computation); + TF_RETURN_IF_ERROR(RunComputation(*called_computation, + called_sequence, points_to_analysis, + module_sequence)); + } + } + + // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are + // assigned "thread-local" allocations, meaning their buffers are not + // allocated up-front at the beginning of the computation. + } + // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. for (const LogicalBuffer* buffer : dead_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } } @@ -187,10 +244,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( const FlatSet<const HloInstruction*>& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; - heap.Free(buffer); + Free(buffer); } - return heap.Finish(); + return Status::OK(); } HeapSimulator::HeapSimulator( @@ -309,6 +366,11 @@ HeapSimulator::Result HeapSimulator::Finish() { result.chunk_map.emplace(buffer, chunk); } } + // If we were told to assign specific buffers, make sure we've assigned + // exactly that many buffers. + if (buffers_to_assign_ != nullptr) { + CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + } } // Fragmentation is the difference between the actual and ideal sizes. diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 0ce2906767..3d98046261 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,17 +64,32 @@ class HeapSimulator { }; // Run the heap simulation with the given algorithm, assuming the given - // sequential ordering of instructions. The 'instruction_sequence' must - // contain a topologically-consistent total ordering of all instructions in - // the computation. The result is invalid if instructions are not run in - // exactly this sequence. + // module_sequence, which must contain a topologically-consistent total + // ordering of all instructions within each computation. The result is invalid + // if instructions are not run in exactly this sequence. + // + // Running heap simulation on the whole module tends to save memory, compared + // to running on a per-computation basis, since we can re-use buffer space for + // called sub-computations. // // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. static StatusOr<Result> Run( + std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign = + nullptr); + + // Same as above, but runs on a single computation. The 'instruction_sequence' + // must contain a topologically-consistent total ordering of all instructions + // in the computation. The result is invalid if instructions are not run in + // exactly this sequence. + static StatusOr<Result> Run( std::unique_ptr<HeapAlgorithm> algorithm, - const std::vector<const HloInstruction*>& instruction_sequence, const HloComputation& computation, + const std::vector<const HloInstruction*>& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign = @@ -86,6 +102,12 @@ class HeapSimulator { const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign); ~HeapSimulator(); + Status RunComputation( + const HloComputation& computation, + const std::vector<const HloInstruction*>& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence); + bool IgnoreBuffer(const LogicalBuffer* buffer) const; void Alloc(const LogicalBuffer* buffer); void Free(const LogicalBuffer* buffer); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 874bd5f106..0a6900f733 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,13 +19,16 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -69,6 +72,7 @@ class HeapCallRecorder : public HeapAlgorithm { // sequence against an expected sequence. class HeapSimulatorTracker { public: + // Constructor for testing a single entry computation. HeapSimulatorTracker( const string& name, std::unique_ptr<HloComputation> computation, const std::vector<const HloInstruction*>& instruction_sequence) { @@ -83,12 +87,48 @@ class HeapSimulatorTracker { auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; auto algorithm = MakeUnique<DecreasingSizeRunsHeap>( MakeUnique<HeapCallRecorder>(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), instruction_sequence, - *module_->entry_computation(), - *points_to_analysis_, zero_size) + result_ = HeapSimulator::Run( + std::move(algorithm), *module_->entry_computation(), + instruction_sequence, *points_to_analysis_, zero_size) .ConsumeValueOrDie(); } + explicit HeapSimulatorTracker(const string& name) { + module_ = MakeUnique<HloModule>(name); + } + + // Similar to the single entry computation constructor above, but runs the + // simulation over the entire module. + void RunWholeModule( + const std::vector<const HloInstruction*>& full_module_sequence) { + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + + // Construct the module sequence grouped by computation. + SequentialHloOrdering::HloModuleSequence module_sequence; + tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position; + for (int i = 0; i < full_module_sequence.size(); ++i) { + const HloInstruction* instruction = full_module_sequence[i]; + module_sequence[instruction->parent()].push_back(instruction); + reverse_position[instruction] = full_module_sequence.size() - i; + } + + // Hack the size_fn so that it returns a decreasing value as we step through + // the sequence. This lets us ensure the Alloc calls are in the sequence + // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // deterministic. + auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + return reverse_position[buffer.instruction()]; + }; + auto algorithm = MakeUnique<DecreasingSizeRunsHeap>( + MakeUnique<HeapCallRecorder>(&actual_calls_)); + result_ = HeapSimulator::Run(std::move(algorithm), *module_, + module_sequence, *points_to_analysis_, size_fn) + .ConsumeValueOrDie(); + } + + HloModule* module() { return module_.get(); } + // Returns the buffer defined at the given instruction and index. const LogicalBuffer* BufferAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -358,6 +398,86 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, WholeModule) { + HeapSimulatorTracker tracker(TestName()); + + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + tracker.module()->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + tracker.module()->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, param)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule( + {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt}); + tracker.ExpectCallSequence({ + // The entry computation param and while_op are allocated first. + {kAlloc, tracker.BufferAt(param, {})}, + {kAlloc, tracker.BufferAt(param, {0})}, + {kAlloc, tracker.BufferAt(param, {1})}, + {kAlloc, tracker.BufferAt(while_op, {})}, + {kAlloc, tracker.BufferAt(while_op, {0})}, + {kAlloc, tracker.BufferAt(while_op, {1})}, + + // Now the while body param is allocated and freed. + {kAlloc, tracker.BufferAt(body_param, {})}, + {kAlloc, tracker.BufferAt(body_param, {0})}, + {kAlloc, tracker.BufferAt(body_param, {1})}, + {kFree, tracker.BufferAt(body_param, {})}, + {kFree, tracker.BufferAt(body_param, {0})}, + {kFree, tracker.BufferAt(body_param, {1})}, + + // Now the while cond param is allocated. The GTE instructions just alias + // the param elements, so the param tuple can immediately be freed. + {kAlloc, tracker.BufferAt(cond_param, {})}, + {kAlloc, tracker.BufferAt(cond_param, {0})}, + {kAlloc, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_param, {})}, + + // Now the final cond less-than buffer is allocated. + {kAlloc, tracker.BufferAt(cond_lt, {})}, + + // The order of the remaining Free calls is based on the LogicalBuffer.id, + // which is deterministic, but not obvious. + {kFree, tracker.BufferAt(param, {})}, + {kFree, tracker.BufferAt(param, {0})}, + {kFree, tracker.BufferAt(param, {1})}, + + {kFree, tracker.BufferAt(while_op, {})}, + {kFree, tracker.BufferAt(while_op, {0})}, + {kFree, tracker.BufferAt(while_op, {1})}, + + {kFree, tracker.BufferAt(cond_param, {0})}, + {kFree, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_lt, {})}, + + {kFinish, nullptr}, + }); +} + // Base class for heap algorithm tests. class HeapAlgorithmTestBase : public ::testing::Test { protected: diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 7476b72f02..725ce17d66 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -221,23 +221,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -namespace { -StatusOr<int64> MinimumMemoryForSequence( - const HloComputation& computation, - const std::vector<const HloInstruction*>& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), sequence, - computation, points_to_analysis, size_function)); - return result.heap_size; -} -} // namespace - StatusOr<int64> MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function) { @@ -249,17 +232,16 @@ StatusOr<int64> MinimumMemoryForSequence( TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(module)); - int64 total_memory = 0; - for (const auto& pair : module_sequence) { - const HloComputation* computation = pair.first; - const std::vector<const HloInstruction*>& sequence = pair.second; - TF_ASSIGN_OR_RETURN( - const int64 memory, - MinimumMemoryForSequence(*computation, sequence, *points_to_analysis, - size_function)); - total_memory += memory; - } - return total_memory; + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; } namespace { @@ -516,6 +498,18 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( return sequence; } +StatusOr<int64> MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector<const HloInstruction*>& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -523,13 +517,17 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector<const HloInstruction*> list_sequence, ListScheduler::Run(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, - MinimumMemoryForSequence(computation, list_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; TF_ASSIGN_OR_RETURN( @@ -537,8 +535,8 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, - MinimumMemoryForSequence(computation, dfs_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; if (list_memory <= dfs_memory) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 01b5fd9364..c387fbb89b 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -155,6 +155,65 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { EXPECT_FALSE(ordering.ExecutesBefore(y, c)); } +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + HloModule module(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module.AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module.AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + } // namespace } // namespace xla |