aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/heap_simulator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc43
1 files changed, 22 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 38c3982ebf..e0f3a7e0e2 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
+ if (schedule.empty()) {
return 0;
}
- const HloModule* module = module_sequence.begin()->first->parent();
+ const HloModule* module = schedule.module();
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -47,14 +47,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
+ schedule, *points_to_analysis, size_function));
return result.heap_size;
}
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
@@ -71,13 +70,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options) {
- HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
+ HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
const HloComputation* entry_computation = module.entry_computation();
- const std::vector<const HloInstruction*>& instruction_sequence =
- FindOrDie(module_sequence, entry_computation);
+ const HloInstructionSequence& instruction_sequence =
+ schedule.sequence(entry_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(
*entry_computation, instruction_sequence, points_to_analysis));
return heap.Finish();
@@ -86,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr, memory_by_computation);
+ /*schedule=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -102,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis) {
VLOG(3) << "Computation:\n" << computation.ToString();
// The goal here is to minimize memory usage, assuming the given sequential
@@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation(
// set of instructions that need to be visited contains all users of all
// aliases, that is, all users of all instructions that have the buffer
// contained in their points-to set.
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const PointsToSet& points_to =
points_to_analysis.GetPointsToSet(instruction);
const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
@@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation(
std::vector<const BufferValue*> dead_buffers_to_free;
std::vector<const BufferValue*> operand_buffers_to_free;
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const TuplePointsToAnalysis::BufferDefinitionVector&
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
@@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation(
// The order that the sub-computations are simulated does not affect
// correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
- if (module_sequence_ != nullptr) {
+ if (schedule_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kConditional ||
instruction->opcode() == HloOpcode::kWhile) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
- const std::vector<const HloInstruction*>& called_sequence =
- FindOrDie(*module_sequence_, called_computation);
+ const HloInstructionSequence& called_sequence =
+ schedule_->sequence(called_computation);
TF_RETURN_IF_ERROR(RunComputation(
*called_computation, called_sequence, points_to_analysis));
}
@@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const HloSchedule* schedule,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence),
+ schedule_(schedule),
memory_by_computation_(memory_by_computation) {
- debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
+ debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}
HeapSimulator::~HeapSimulator() {}