author     2018-06-14 17:22:37 -0700
committer  2018-06-14 17:25:24 -0700
commit     7e05b8a1c7fec4852e275e708555a759947270d7 (patch)
tree       d50789d00d38c0ff5cbf56d4780149760a55d3f0
parent     9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (diff)
[TF:XLA] Account for subcomputations in heap simulator during scheduling.
PiperOrigin-RevId: 200646674
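
[Editor's note, not part of the commit message: the new accounting is opt-in,
via the optional memory_by_computation map added to
HeapSimulator::MinimumMemoryForComputation in this change. Below is a minimal
sketch of the call pattern, mirroring the tests in this diff; the helper name
PeakIncludingSubcomputations is hypothetical, and the computation, schedule,
points-to analysis, and size function are assumed to come from an existing
module.]

  #include "tensorflow/compiler/xla/service/heap_simulator.h"

  namespace xla {

  // Sketch: estimate the peak memory of a scheduled computation, letting the
  // simulator add the cost of the largest called subcomputation at each call
  // site, as the tests in this change do.
  StatusOr<int64> PeakIncludingSubcomputations(
      const HloComputation& computation,
      const std::vector<const HloInstruction*>& sequence,
      const TuplePointsToAnalysis& points_to_analysis,
      const LogicalBuffer::SizeFunction& size_fn,
      const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
          memory_by_computation) {
    // Passing the map enables the new subcomputation accounting; passing
    // nullptr (the default) preserves the old behavior.
    return HeapSimulator::MinimumMemoryForComputation(
        computation, sequence, points_to_analysis, size_fn,
        &memory_by_computation);
  }

  }  // namespace xla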
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                   |   1
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc    |   5
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.cc       |  52
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.h        |  58
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc  |   3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling.cc       |  37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling_test.cc  | 104
7 files changed, 204 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index cb2e159a38..396ce13e7f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1101,6 +1101,7 @@ tf_cc_test(
     srcs = ["hlo_scheduling_test.cc"],
     deps = [
         ":buffer_value",
+        ":heap_simulator",
         ":hlo",
         ":hlo_ordering",
         ":hlo_scheduling",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 5d3b0cb333..afe4b2e142 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -631,8 +631,9 @@ Status BufferAssignment::ComputeSummaryStats() {
     }
   }
   if (module_sequence.size() == module_->computation_count()) {
-    TF_ASSIGN_OR_RETURN(const int64 min_size,
-                        MinimumMemoryForModule(module_sequence, buffer_size_));
+    TF_ASSIGN_OR_RETURN(
+        const int64 min_size,
+        HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
     stats_.total_fragmentation_bytes =
         stats_.total_allocation_bytes - min_size;
   }
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 5dba50a63b..a04aa4069d 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -26,7 +26,8 @@ namespace xla {
 using tensorflow::gtl::FlatMap;
 using tensorflow::gtl::FlatSet;
 
-StatusOr<int64> MinimumMemoryForModule(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const LogicalBuffer::SizeFunction& size_function) {
   if (module_sequence.empty()) {
@@ -49,15 +50,19 @@ StatusOr<int64> MinimumMemoryForModule(
   return result.heap_size;
 }
 
-StatusOr<int64> MinimumMemoryForComputation(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
     const HloComputation& computation,
     const std::vector<const HloInstruction*>& sequence,
     const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function) {
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation) {
   TF_ASSIGN_OR_RETURN(
       HeapSimulator::Result result,
       HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
-                         sequence, points_to_analysis, size_function));
+                         sequence, points_to_analysis, size_function,
+                         HeapSimulator::Options(), memory_by_computation));
   return result.heap_size;
 }
 
@@ -81,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
     const std::vector<const HloInstruction*>& instruction_sequence,
     const TuplePointsToAnalysis& points_to_analysis,
-    const BufferValue::SizeFunction& size_fn, const Options& options) {
+    const BufferValue::SizeFunction& size_fn, const Options& options,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
-                     /*module_sequence=*/nullptr);
+                     /*module_sequence=*/nullptr, memory_by_computation);
   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                          points_to_analysis));
   return heap.Finish();
@@ -254,6 +261,12 @@ Status HeapSimulator::RunComputation(
         Alloc(buffer, instruction);
       }
     }
+    // Account for the memory used by subcomputations when estimating the
+    // current heap size.
+    if (memory_by_computation_ != nullptr) {
+      algorithm_->AccountForSubcomputationMemory(instruction,
+                                                 *memory_by_computation_);
+    }
 
     // If the whole module is sequential, we can save memory by running the
     // heap-simulation for sub-computations inline. E.g. the buffers for the
@@ -321,12 +334,15 @@ Status HeapSimulator::RunComputation(
 HeapSimulator::HeapSimulator(
     std::unique_ptr<HeapAlgorithm> algorithm,
     const BufferValue::SizeFunction& size_fn, const Options& options,
-    const SequentialHloOrdering::HloModuleSequence* module_sequence)
+    const SequentialHloOrdering::HloModuleSequence* module_sequence,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+        memory_by_computation)
     : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
      options_(options),
-      module_sequence_(module_sequence) {
+      module_sequence_(module_sequence),
+      memory_by_computation_(memory_by_computation) {
   debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
 }
@@ -495,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
   }
 }
 
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+    const HloInstruction* instruction,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation) {
+  // We only count the memory usage of the largest subcomputation, instead of
+  // adding them all, because subcomputations won't execute in parallel.
+  int64 max_subcomputation_bytes = 0;
+  for (const auto* c : instruction->called_computations()) {
+    auto it = memory_by_computation.find(c);
+    if (it != memory_by_computation.end()) {
+      int64 subcomputation_bytes = it->second;
+      if (subcomputation_bytes > max_subcomputation_bytes) {
+        max_subcomputation_bytes = subcomputation_bytes;
+      }
+    }
+  }
+  max_heap_size_ =
+      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
 void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
   current_heap_size_ -= size;
 }
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 3be3bb8e7f..811a6042df 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -34,21 +34,6 @@ limitations under the License.
 
 namespace xla {
 
-// Returns the minimum memory required to compute an HLO module where all
-// computations have been scheduled (represented by the given module_sequence),
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForModule(
-    const SequentialHloOrdering::HloModuleSequence& module_sequence,
-    const LogicalBuffer::SizeFunction& size_function);
-
-// Returns the minimum memory required to compute the given computation,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForComputation(
-    const HloComputation& computation,
-    const std::vector<const HloInstruction*>& sequence,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function);
-
 // Forward declare classes defined below.
 class HeapAlgorithm;
 
@@ -100,6 +85,23 @@ class HeapSimulator {
     const BufferValueFlatSet* buffers_to_assign;
   };
 
+  // Returns the minimum memory required to compute an HLO module where all
+  // computations have been scheduled (represented by the given
+  // module_sequence), assuming no fragmentation.
+  static StatusOr<int64> MinimumMemoryForModule(
+      const SequentialHloOrdering::HloModuleSequence& module_sequence,
+      const LogicalBuffer::SizeFunction& size_function);
+
+  // Returns the minimum memory required to compute the given computation,
+  // assuming no fragmentation.
+  static StatusOr<int64> MinimumMemoryForComputation(
+      const HloComputation& computation,
+      const std::vector<const HloInstruction*>& sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const LogicalBuffer::SizeFunction& size_function,
+      const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+          memory_by_computation = nullptr);
+
   // Run the heap simulation with the given algorithm, assuming the given
   // module_sequence, which must contain a topologically-consistent total
   // ordering of all instructions within each computation. The result is invalid
@@ -126,7 +128,9 @@ class HeapSimulator {
       const std::vector<const HloInstruction*>& instruction_sequence,
       const TuplePointsToAnalysis& points_to_analysis,
       const BufferValue::SizeFunction& size_fn,
-      const Options& options = Options());
+      const Options& options = Options(),
+      const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+          memory_by_computation = nullptr);
 
  private:
   // If 'module_sequence' is non-null, it is used to find kCall and kWhile
@@ -135,7 +139,9 @@ class HeapSimulator {
   HeapSimulator(
       std::unique_ptr<HeapAlgorithm> algorithm,
       const BufferValue::SizeFunction& size_fn, const Options& options,
-      const SequentialHloOrdering::HloModuleSequence* module_sequence);
+      const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
+      const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+          memory_by_computation = nullptr);
   ~HeapSimulator();
 
   Status RunComputation(
@@ -159,7 +165,13 @@ class HeapSimulator {
   const std::unique_ptr<HeapAlgorithm> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
+  // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+  // set by hlo scheduling. Then, in RunComputation, we check both in order to
+  // handle subcomputations. It would be good to unify the handling of
+  // subcomputations, but it's not clear how.
   const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+  const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+      memory_by_computation_;
 
   // In addition to Alloc and Free, the heap simulator exposes a concept of
   // buffer sharing. When ShareBuffer is called, instead of allocating new
@@ -204,6 +216,11 @@ class HeapAlgorithm {
   // Alloc allocates a buffer of 'size' bytes.
   virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
 
+  virtual void AccountForSubcomputationMemory(
+      const HloInstruction* instruction,
+      const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+          memory_by_computation) {}
+
   // Free de-allocates a previously allocated buffer.
   virtual void Free(const BufferValue* buffer, int64 size) = 0;
@@ -222,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
   ~NoFragmentationStatsHeap() override = default;
 
   void Alloc(const BufferValue* buffer, int64 size) override;
+
+  void AccountForSubcomputationMemory(
+      const HloInstruction* instruction,
+      const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+          memory_by_computation) override;
+
   void Free(const BufferValue* buffer, int64 size) override;
+
   Result Finish() override;
 
  private:
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 309ab85f78..93d7a14125 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -89,7 +89,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
                                         cond_lt};
   module_sequence[body_computation] = {body_param};
   module_sequence[entry_computation] = {iter, data, tuple, while_op};
-  EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie());
+  EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
+                    .ValueOrDie());
 }
 
 const char kAlloc[] = "Alloc";
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index b14ade3549..641b9ecec9 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -375,7 +375,7 @@ int64 SumLogicalBufferSizes(
   return size;
 }
 
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationsInModule(
+StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
@@ -498,29 +498,29 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
       std::vector<const HloInstruction*> list_sequence,
       ListMemoryScheduler(computation, points_to_analysis, size_function,
                           memory_by_computation));
-  TF_ASSIGN_OR_RETURN(
-      const int64 list_memory,
-      MinimumMemoryForComputation(computation, list_sequence,
-                                  points_to_analysis, size_function));
+  TF_ASSIGN_OR_RETURN(const int64 list_memory,
+                      HeapSimulator::MinimumMemoryForComputation(
+                          computation, list_sequence, points_to_analysis,
+                          size_function, &memory_by_computation));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
 
   TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
                       DFSMemoryScheduler(computation, points_to_analysis,
                                          size_function, memory_by_computation));
-  TF_ASSIGN_OR_RETURN(
-      const int64 dfs_memory,
-      MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
-                                  size_function));
+  TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
+                      HeapSimulator::MinimumMemoryForComputation(
+                          computation, dfs_sequence, points_to_analysis,
+                          size_function, &memory_by_computation));
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
 
   TF_ASSIGN_OR_RETURN(
       std::vector<const HloInstruction*> post_order_sequence,
       PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
                                memory_by_computation));
-  TF_ASSIGN_OR_RETURN(
-      const int64 post_order_memory,
-      MinimumMemoryForComputation(computation, post_order_sequence,
-                                  points_to_analysis, size_function));
+  TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
+                      HeapSimulator::MinimumMemoryForComputation(
+                          computation, post_order_sequence, points_to_analysis,
+                          size_function, &memory_by_computation));
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
@@ -551,12 +551,13 @@ StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
   for (const auto* computation : module.MakeComputationPostOrder()) {
     if (!computation->IsFusionComputation()) {
       TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
-                          ScheduleComputationsInModule(
+                          ScheduleComputationHelper(
                               *computation, *points_to_analysis, size_function,
                               algorithm, memory_by_computation));
       memory_by_computation[computation] =
-          MinimumMemoryForComputation(*computation, one_computation_sequence,
-                                      *points_to_analysis, size_function)
+          HeapSimulator::MinimumMemoryForComputation(
+              *computation, one_computation_sequence, *points_to_analysis,
+              size_function, &memory_by_computation)
               .ValueOrDie();
       sequence[computation] = std::move(one_computation_sequence);
     }
@@ -571,8 +572,8 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(computation.parent()));
   tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
-  return ScheduleComputationsInModule(computation, *points_to_analysis,
-                                      size_function, nullptr, empty_map);
+  return ScheduleComputationHelper(computation, *points_to_analysis,
+                                   size_function, nullptr, empty_map);
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 6f1b1215d3..73f22f81f4 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <string>
 
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -144,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
   //   ROOT %subtract = f32[4]{0} subtract(
   //       f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
   // }
-  // %SubcomputationsNotAccounted () -> f32[2,4] {
+  // %ListAccountsForSubcomputations () -> f32[2,4] {
   //   %constant.3 = f32[2,4]{1,0} constant(
   //       f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
   //   %transpose = f32[2,4]{1,0} transpose(
@@ -210,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
 
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
-                          ScheduleComputationsInModule(
-                              *module,
-                              [](const BufferValue& buffer) {
-                                return ShapeUtil::ByteSizeOf(buffer.shape());
-                              },
-                              ListMemoryScheduler));
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  };
+  TF_ASSERT_OK_AND_ASSIGN(
+      SequentialHloOrdering::HloModuleSequence sequence,
+      ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
   // Verify that all instructions are in the sequence.
-  EXPECT_EQ(module->entry_computation()->instruction_count(),
-            sequence.at(module->entry_computation()).size());
+  auto entry_computation = module->entry_computation();
+  EXPECT_EQ(entry_computation->instruction_count(),
+            sequence.at(entry_computation).size());
   SequentialHloOrdering ordering(module.get(), sequence);
   // This schedule is an example of List's greedy heuristics being suboptimal.
   // The while_loop is more expensive than transpose, so it would have been
   // better to schedule it first, instead of during the busy time.
   EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
   EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
   EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+  memory_by_computation[cond_computation] = 17;
+  memory_by_computation[body_computation] = 16;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+  // HeapSimulator doesn't account for subcomputations
+  EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, sequence.at(entry_computation),
+                    *points_to_analysis, size_fn)
+                    .ValueOrDie());
+  // HeapSimulator accounts for subcomputations. The max mem doesn't change
+  // because the while body isn't live during the peak.
+  EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, sequence.at(entry_computation),
+                    *points_to_analysis, size_fn, &memory_by_computation)
+                    .ValueOrDie());
 }
 
 TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
@@ -325,5 +344,70 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
   EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
 }
 
+TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
+  auto module = CreateNewModule();
+  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+  // param != 0
+  // Needs 17 bytes
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+  HloInstruction* zero_vector = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+  cond_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+  // param - 1
+  // Needs 16 bytes
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "body_param"));
+  HloInstruction* one_vector = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+  body_builder.AddInstruction(HloInstruction::CreateBinary(
+      r1f32, HloOpcode::kSubtract, body_param, one_vector));
+  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* while_init = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+  // Creates 16 bytes, ignoring subcomputations
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      r1f32, cond_computation, body_computation, while_init));
+
+  module->AddEntryComputation(builder.Build());
+
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  };
+  TF_ASSERT_OK_AND_ASSIGN(
+      SequentialHloOrdering::HloModuleSequence sequence,
+      ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+  // Verify that all instructions are in the sequence.
+  auto entry_computation = module->entry_computation();
+  EXPECT_EQ(entry_computation->instruction_count(),
+            sequence.at(entry_computation).size());
+
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+  memory_by_computation[cond_computation] = 17;
+  memory_by_computation[body_computation] = 16;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+  // HeapSimulator doesn't account for subcomputations
+  EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, sequence.at(entry_computation),
+                    *points_to_analysis, size_fn)
+                    .ValueOrDie());
+  // HeapSimulator accounts for subcomputations
+  EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, sequence.at(entry_computation),
+                    *points_to_analysis, size_fn, &memory_by_computation)
+                    .ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla
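
[Editor's note on the arithmetic in the last test:
NoFragmentationStatsHeap::AccountForSubcomputationMemory adds only the largest
called computation to the live heap size, because the condition and body of a
while never run in parallel. A standalone restatement of that rule with the
test's numbers follows; it is a sketch of the accounting, not the class
itself.]

  #include <algorithm>
  #include <cstdint>
  #include <iostream>

  int main() {
    // From HeapSimulatorAccountsForSubcomputations: 16 bytes are live at the
    // while op, the condition needs 17 bytes, and the body needs 16.
    int64_t current_heap_size = 16;
    int64_t max_heap_size = 16;

    // Only the largest called computation is counted, since subcomputations
    // do not execute in parallel.
    int64_t max_subcomputation_bytes = std::max<int64_t>(17, 16);

    max_heap_size =
        std::max(max_heap_size, current_heap_size + max_subcomputation_bytes);
    std::cout << max_heap_size << "\n";  // Prints 33, matching EXPECT_EQ(33, ...).
    return 0;
  }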