diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator.cc | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf..e0f3a7e0e2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,14 +47,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector<const HloInstruction*>& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* @@ -71,13 +70,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr<HeapSimulator::Result> HeapSimulator::Run( std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector<const HloInstruction*>& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( /*static*/ StatusOr<HeapSimulator::Result> HeapSimulator::Run( std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector<const BufferValue*> dead_buffers_to_free; std::vector<const BufferValue*> operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. - if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector<const HloInstruction*>& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr<HeapAlgorithm> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation) : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} |