aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/heap_simulator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-02 17:21:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-02 18:30:05 -0700
commit5ad12420e78d0aa756fd2a41945468e826e267c2 (patch)
tree400a2866984e554e7ba1ac0de02682822ef0d6dc /tensorflow/compiler/xla/service/heap_simulator.cc
parent58196d4bf923d6fa2500e84d9d22ed8227ba305c (diff)
[XLA:HLO] Run HeapSimulator on whole-module if all computations are sequential.
Previously the HeapSimulator was only run on a per-computation basis. This meant that if you had many sub-computations in your module (e.g. many While loops), the space for all of the temporary buffers inside the conditions and bodies of the loops were in distinct memory ranges. This is overly pessimistic if all computations in the module are sequential. This CL changes the HeapSimulator to also run whole-module simulation, calling Alloc and Free on sub-computation buffers at the appropriate nested spot, right next to the calling instruction. The BufferAssigner is updated to take advantage of this when possible, as is MinimumMemoryForSequence. Change: 154908856
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc90
1 files changed, 76 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 9c4899a67d..d7aa5664df 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -53,12 +53,44 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
- std::unique_ptr<HeapAlgorithm> algorithm,
+ std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_fn,
+ const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
+ HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+ const HloComputation* entry_computation = module.entry_computation();
+ const std::vector<const HloInstruction*>& instruction_sequence =
+ FindOrDie(module_sequence, entry_computation);
+ TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
+ instruction_sequence,
+ points_to_analysis, &module_sequence));
+ return heap.Finish();
+}
+
+/*static*/
+StatusOr<HeapSimulator::Result> HeapSimulator::Run(
+ std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
- const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_fn,
const FlatSet<const LogicalBuffer*>* buffers_to_assign) {
+ HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
+ TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
+ points_to_analysis,
+ /*module_sequence=*/nullptr));
+ return heap.Finish();
+}
+
+// Runs a heap simulation for the given 'computation', assuming the given
+// 'instruction_sequence'. If 'module_sequence' is non-null, it is used to find
+// kCall and kWhile sub-computations, and the heap simulation for those
+// sub-computations will be run recursively.
+Status HeapSimulator::RunComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& instruction_sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const SequentialHloOrdering::HloModuleSequence* module_sequence) {
// The goal here is to minimize memory usage, assuming the given sequential
// ordering of instructions. The strategy is to walk through the instruction
// sequence, calling Alloc and Free on the underlying heap algorithm. The
@@ -67,7 +99,6 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// 'live_buffers' tracks the liveness of each buffer that we assign, by
// associating it with a set of HloInstructions that need to be visited. When
// the set becomes empty, the buffer is no longer used, and can be freed.
- HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign);
FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
const HloInstruction* root = computation.root_instruction();
@@ -90,7 +121,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// lifetime of buffers that aren't already connected by a data dependency.
std::vector<const LogicalBuffer*> dead_buffers_to_free;
for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
- if (heap.IgnoreBuffer(buffer)) {
+ if (IgnoreBuffer(buffer)) {
continue;
}
for (const BufferAlias& alias :
@@ -127,7 +158,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::vector<const LogicalBuffer*> operand_buffers_to_free;
for (const LogicalBuffer* operand_buffer :
UniqueOperandSourceBuffers(instruction, points_to_analysis)) {
- if (heap.IgnoreBuffer(operand_buffer)) {
+ if (IgnoreBuffer(operand_buffer)) {
continue;
}
live_buffers[operand_buffer].erase(instruction);
@@ -142,10 +173,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// happen before dead or operand buffers are freed; the instruction reads
// the operand buffers to produce its output.
//
- // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each
- // buffer that we should assign.
+ // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
+ // that we should assign.
for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
- if (heap.IgnoreBuffer(buffer)) {
+ if (IgnoreBuffer(buffer)) {
continue;
}
@@ -159,24 +190,50 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
CanShareOperandBufferWithUser(
operand_buffer->instruction(), operand_buffer->index(),
buffer->instruction(), buffer->index(), points_to_analysis)) {
- heap.ShareBuffer(buffer, operand_buffer);
+ ShareBuffer(buffer, operand_buffer);
shared = true;
break;
}
}
if (!shared) {
- heap.Alloc(buffer);
+ Alloc(buffer);
}
}
+ // If the whole module is sequential, we can save memory by running the
+ // heap-simulation for sub-computations inline. E.g. the buffers for the
+ // condition and body of a kWhile instruction are only live for the duration
+ // of the instruction itself.
+ //
+ // The order that the sub-computations are simulated does not affect
+ // correctness; since the whole module is sequential, we know that the
+ // sub-computations will never be run concurrently.
+ if (module_sequence != nullptr) {
+ if (instruction->opcode() == HloOpcode::kCall ||
+ instruction->opcode() == HloOpcode::kWhile) {
+ for (const HloComputation* called_computation :
+ instruction->called_computations()) {
+ const std::vector<const HloInstruction*>& called_sequence =
+ FindOrDie(*module_sequence, called_computation);
+ TF_RETURN_IF_ERROR(RunComputation(*called_computation,
+ called_sequence, points_to_analysis,
+ module_sequence));
+ }
+ }
+
+ // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
+ // assigned "thread-local" allocations, meaning their buffers are not
+ // allocated up-front at the beginning of the computation.
+ }
+
// Free buffers that are no longer live. This is the earliest point that we
// can de-allocate; right after the last use of the buffer.
for (const LogicalBuffer* buffer : dead_buffers_to_free) {
- heap.Free(buffer);
+ Free(buffer);
}
for (const LogicalBuffer* buffer : operand_buffers_to_free) {
- heap.Free(buffer);
+ Free(buffer);
}
}
@@ -187,10 +244,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
- heap.Free(buffer);
+ Free(buffer);
}
- return heap.Finish();
+ return Status::OK();
}
HeapSimulator::HeapSimulator(
@@ -309,6 +366,11 @@ HeapSimulator::Result HeapSimulator::Finish() {
result.chunk_map.emplace(buffer, chunk);
}
}
+ // If we were told to assign specific buffers, make sure we've assigned
+ // exactly that many buffers.
+ if (buffers_to_assign_ != nullptr) {
+ CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size());
+ }
}
// Fragmentation is the difference between the actual and ideal sizes.