author | Dimitris Vardoulakis <dimvar@google.com> | 2018-06-14 17:22:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-14 17:25:24 -0700 |
commit | 7e05b8a1c7fec4852e275e708555a759947270d7 (patch) | |
tree | d50789d00d38c0ff5cbf56d4780149760a55d3f0 /tensorflow/compiler/xla/service/heap_simulator.cc | |
parent | 9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (diff) | |
[TF:XLA] Account for subcomputations in heap simulator during scheduling.
PiperOrigin-RevId: 200646674
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator.cc | 52 |
1 file changed, 44 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 5dba50a63b..a04aa4069d 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -26,7 +26,8 @@ namespace xla {
 using tensorflow::gtl::FlatMap;
 using tensorflow::gtl::FlatSet;
 
-StatusOr<int64> MinimumMemoryForModule(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const LogicalBuffer::SizeFunction& size_function) {
   if (module_sequence.empty()) {
@@ -49,15 +50,19 @@ StatusOr<int64> MinimumMemoryForModule(
   return result.heap_size;
 }
 
-StatusOr<int64> MinimumMemoryForComputation(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const HloComputation& computation,
     const std::vector<const HloInstruction*>& sequence,
     const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function) {
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation) {
   TF_ASSIGN_OR_RETURN(
       HeapSimulator::Result result,
       HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
-                         sequence, points_to_analysis, size_function));
+                         sequence, points_to_analysis, size_function,
+                         HeapSimulator::Options(), memory_by_computation));
   return result.heap_size;
 }
 
@@ -81,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
     const std::vector<const HloInstruction*>& instruction_sequence,
     const TuplePointsToAnalysis& points_to_analysis,
-    const BufferValue::SizeFunction& size_fn, const Options& options) {
+    const BufferValue::SizeFunction& size_fn, const Options& options,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
-                     /*module_sequence=*/nullptr);
+                     /*module_sequence=*/nullptr, memory_by_computation);
   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                          points_to_analysis));
   return heap.Finish();
@@ -254,6 +261,12 @@ Status HeapSimulator::RunComputation(
       Alloc(buffer, instruction);
     }
   }
+  // Account for the memory used by subcomputations when estimating the
+  // current heap size.
+  if (memory_by_computation_ != nullptr) {
+    algorithm_->AccountForSubcomputationMemory(instruction,
+                                               *memory_by_computation_);
+  }
   // If the whole module is sequential, we can save memory by running the
   // heap-simulation for sub-computations inline. E.g. the buffers for the
   // condition and body of a kWhile instruction are only live for the duration
@@ -321,12 +334,15 @@ Status HeapSimulator::RunComputation(
 HeapSimulator::HeapSimulator(
     std::unique_ptr<HeapAlgorithm> algorithm,
     const BufferValue::SizeFunction& size_fn, const Options& options,
-    const SequentialHloOrdering::HloModuleSequence* module_sequence)
+    const SequentialHloOrdering::HloModuleSequence* module_sequence,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation)
     : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
       options_(options),
-      module_sequence_(module_sequence) {
+      module_sequence_(module_sequence),
+      memory_by_computation_(memory_by_computation) {
   debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
 }
 
@@ -495,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
   }
 }
 
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+    const HloInstruction* instruction,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation) {
+  // We only count the memory usage of the largest subcomputation, instead of
+  // adding them all, because subcomputations won't execute in parallel.
+  int64 max_subcomputation_bytes = 0;
+  for (const auto* c : instruction->called_computations()) {
+    auto it = memory_by_computation.find(c);
+    if (it != memory_by_computation.end()) {
+      int64 subcomputation_bytes = it->second;
+      if (subcomputation_bytes > max_subcomputation_bytes) {
+        max_subcomputation_bytes = subcomputation_bytes;
+      }
+    }
+  }
+  max_heap_size_ =
+      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
 void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
   current_heap_size_ -= size;
 }
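
Taken on its own, the new `AccountForSubcomputationMemory` hook encodes one idea: when an instruction calls subcomputations, the peak-heap watermark rises by the footprint of the largest callee only, because (as the diff's comment notes) callees run one at a time, so their footprints need not be summed. Below is a minimal, self-contained sketch of that accounting rule; `Computation`, `Instruction`, and `HeapStats` are illustrative stand-ins for the HLO types, not TensorFlow code.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Computation {};  // stand-in for HloComputation

// Stand-in for HloInstruction: only the called-computations edge matters here.
struct Instruction {
  std::vector<const Computation*> called_computations;
};

struct HeapStats {
  int64_t current_heap_size = 0;
  int64_t max_heap_size = 0;

  // Mirrors the logic added to NoFragmentationStatsHeap above: the peak
  // estimate grows by the largest callee's footprint, not the sum, because
  // callees never execute in parallel.
  void AccountForSubcomputationMemory(
      const Instruction& instruction,
      const std::unordered_map<const Computation*, int64_t>&
          memory_by_computation) {
    int64_t max_subcomputation_bytes = 0;
    for (const Computation* c : instruction.called_computations) {
      auto it = memory_by_computation.find(c);
      if (it != memory_by_computation.end()) {
        max_subcomputation_bytes =
            std::max(max_subcomputation_bytes, it->second);
      }
    }
    max_heap_size =
        std::max(max_heap_size, current_heap_size + max_subcomputation_bytes);
  }
};

int main() {
  Computation cond, body;  // e.g. the condition and body of a kWhile
  Instruction while_like{{&cond, &body}};
  std::unordered_map<const Computation*, int64_t> memory_by_computation{
      {&cond, 128}, {&body, 512}};

  HeapStats stats;
  stats.current_heap_size = 1024;
  stats.AccountForSubcomputationMemory(while_like, memory_by_computation);
  std::cout << stats.max_heap_size << "\n";  // 1536 = 1024 + max(128, 512)
}
```

With a current heap of 1024 bytes and callee estimates of 128 and 512 bytes, the watermark becomes 1536 rather than 1664: the two callees never coexist on the heap, so only the larger one is charged.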