author    Michael Kuperstein <mkuper@google.com>    2018-02-26 14:19:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-02-26 14:23:37 -0800
commit d98e7fc5720c1597b6f2034ba2ad62438ac5ef39 (patch)
tree   3e9063ca7a9ce572b73475508b5a4060f6f887d3 /tensorflow/compiler/xla/service/heap_simulator.cc
parent a05488be720fc803ac56738c8bc0222fb8a36d7f (diff)
[XLA] GTE of a certain element of the tuple does not need to keep other elements alive.
This achieves two things:
1. Heap simulation runtime is no longer quadratic in the number of tuple elements (as we don't add each GetTupleElement to the live set of each buffer defined by the tuple).
2. A reduction in the heap memory footprint.

PiperOrigin-RevId: 187079787
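To make the intent concrete, here is a minimal standalone sketch of the new liveness bookkeeping, not the XLA code itself: Buffer, Instr, AddUser, RegisterUsers and the points_to layout are hypothetical simplifications of LogicalBuffer, HloInstruction, add_user_to_buffer and TuplePointsToAnalysis. As in the diff below, the GTE branch approximates "the element being extracted" by the GTE's own points-to set plus the top-level tuple buffer.

#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical simplified stand-ins for LogicalBuffer / HloInstruction; the
// real pass gets this information from TuplePointsToAnalysis.
struct Buffer { int id; };
struct Instr {
  std::string name;
  bool is_gte = false;
  // points_to[0] stands for the top-level ({}) element of the output; the
  // remaining entries stand for the individual tuple elements.
  std::vector<std::vector<const Buffer*>> points_to;
};

using LiveMap = std::map<const Buffer*, std::set<const Instr*>>;
using UsedMap = std::map<const Instr*, std::set<const Buffer*>>;

// Mirrors add_user_to_buffer: record that `user` keeps `buffer` alive.
void AddUser(const Instr* user, const Buffer* buffer, LiveMap& live, UsedMap& used) {
  live[buffer].insert(user);
  used[user].insert(buffer);
}

// Before this change, every user of `operand` kept every buffer in the
// operand's flattened points-to set alive. After it, a GetTupleElement user
// only keeps the top-level tuple buffer plus the buffers its own output can
// point to (i.e. the element it extracts) alive.
void RegisterUsers(const Instr& operand, const std::vector<const Instr*>& users,
                   LiveMap& live, UsedMap& used) {
  for (const Instr* user : users) {
    if (!user->is_gte) {
      for (const auto& element : operand.points_to)
        for (const Buffer* b : element) AddUser(user, b, live, used);
    } else {
      for (const Buffer* b : operand.points_to[0])   // the tuple itself
        AddUser(user, b, live, used);
      for (const auto& element : user->points_to)    // the extracted element
        for (const Buffer* b : element) AddUser(user, b, live, used);
    }
  }
}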
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator.cc')
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.cc | 135
1 file changed, 77 insertions(+), 58 deletions(-)
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index a2d13c013c..3dd4c4a079 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -27,38 +27,6 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-namespace {
-
-// Returns the set of buffers that may be sources of all operands of the given
-// instruction. The returned buffers are guaranteed to have no duplicates, and
-// to be sorted in a deterministic order.
-std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
- const HloInstruction* instruction,
- const TuplePointsToAnalysis& points_to_analysis) {
- std::vector<const LogicalBuffer*> buffers;
- for (const HloInstruction* operand : instruction->operands()) {
- points_to_analysis.GetPointsToSet(operand).ForEachElement(
- [&](const ShapeIndex& /*index*/,
- const PointsToSet::BufferList& points_to) {
- buffers.insert(buffers.end(), points_to.begin(), points_to.end());
- });
- }
-
- // Sort and then remove duplicates from buffers.
- std::sort(buffers.begin(), buffers.end(),
- [](const LogicalBuffer* a, const LogicalBuffer* b) {
- return a->id() < b->id();
- });
- buffers.erase(std::unique(buffers.begin(), buffers.end(),
- [](const LogicalBuffer* a, const LogicalBuffer* b) {
- return a->id() == b->id();
- }),
- buffers.end());
- return buffers;
-}
-
-} // namespace
-
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
@@ -93,6 +61,7 @@ Status HeapSimulator::RunComputation(
const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis) {
+ VLOG(3) << "Computation:\n" << computation.ToString();
// The goal here is to minimize memory usage, assuming the given sequential
// ordering of instructions. The strategy is to walk through the instruction
// sequence, calling Alloc and Free on the underlying heap algorithm. The
@@ -101,7 +70,51 @@ Status HeapSimulator::RunComputation(
// 'live_buffers' tracks the liveness of each buffer that we assign, by
// associating it with a set of HloInstructions that need to be visited. When
// the set becomes empty, the buffer is no longer used, and can be freed.
+ // 'used_buffers' is the reverse map - it tracks which buffers were used by an
+ // instruction, so that we can remove the instructions from a buffer's live
+ // set after they are visited.
FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
+ FlatMap<const HloInstruction*, FlatSet<const LogicalBuffer*>> used_buffers;
+ auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
+ const HloInstruction* user,
+ const LogicalBuffer* buffer) {
+ if (!IgnoreBuffer(buffer)) {
+ VLOG(4) << " Adding user " << user->name() << " to buffer "
+ << buffer->ToString();
+ live_buffers[buffer].insert(user);
+ used_buffers[user].insert(buffer);
+ }
+ };
+
+ // Initialize live_buffers for each buffer that we're going to assign. The
+ // set of instructions that need to be visited contains all users of all
+ // aliases, that is, all users of all instructions that have the buffer
+ // contained in their points-to set.
+ for (const HloInstruction* instruction : instruction_sequence) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction);
+ const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
+ for (const HloInstruction* user : instruction->users()) {
+ if (user->opcode() != HloOpcode::kGetTupleElement) {
+ for (const LogicalBuffer* buffer : buffer_set) {
+ add_user_to_buffer(user, buffer);
+ }
+ } else {
+ // A GetTupleElement doesn't need to keep all of its operand's buffers
+ // alive. It only needs the buffers that relate to the element it's
+ // extracting, and the tuple it's extracting from, but not the buffers
+ // for the other elements.
+ for (const LogicalBuffer* buffer : points_to.element({})) {
+ add_user_to_buffer(user, buffer);
+ }
+ const PointsToSet& gte_points_to =
+ points_to_analysis.GetPointsToSet(user);
+ for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) {
+ add_user_to_buffer(user, buffer);
+ }
+ }
+ }
+ }
const HloInstruction* root = computation.root_instruction();
auto output_source_buffers =
@@ -114,34 +127,17 @@ Status HeapSimulator::RunComputation(
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
- // Initialize live_buffers for each buffer that we're going to assign. The
- // set of instructions that need to be visited contains all users of all
- // aliases. The alias itself is not necessary; if it has users, the users
- // are necessarily scheduled after the alias. And if it has no users, it is
- // either a dead value or an output, both of which are handled below.
- //
- // We ignore control dependencies here. The reasoning is that the control
- // dependencies have already been accounted for in the ordering of the given
- // 'instruction_sequence', and should not otherwise artificially extend the
- // lifetime of buffers that aren't already connected by a data dependency.
+ VLOG(3) << "Instruction: " << instruction->ToString();
+ for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
+ VLOG(4) << " Defines: " << buffer->ToString()
+ << (IgnoreBuffer(buffer) ? " (Ignored)" : "");
+ }
+
dead_buffers_to_free.clear();
for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
}
- FlatSet<const HloInstruction*>* live_set = nullptr;
- for (const BufferAlias& alias :
- points_to_analysis.GetBufferAliases(*buffer)) {
- const std::vector<HloInstruction*>& users =
- alias.instruction()->users();
- if (!users.empty()) {
- if (live_set == nullptr) {
- live_set = &live_buffers[buffer];
- }
- live_set->insert(users.begin(), users.end());
- }
- }
-
// Add a nullptr sentry to ensure entry parameters and output source
// buffers are not freed until the very end.
const bool entry_parameter =
@@ -165,11 +161,12 @@ Status HeapSimulator::RunComputation(
// have no instructions left to visit are moved from live_buffers to
// operand_buffers_to_free.
operand_buffers_to_free.clear();
- for (const LogicalBuffer* operand_buffer :
- UniqueOperandSourceBuffers(instruction, points_to_analysis)) {
+ for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) {
if (IgnoreBuffer(operand_buffer)) {
continue;
}
+ VLOG(4) << " Removing user " << instruction->name() << " from buffer "
+ << operand_buffer->ToString();
auto it = live_buffers.find(operand_buffer);
FlatSet<const HloInstruction*>* live_set = &it->second;
live_set->erase(instruction);
@@ -178,6 +175,11 @@ Status HeapSimulator::RunComputation(
operand_buffers_to_free.push_back(operand_buffer);
}
}
+ // Sort to get a deterministic iteration order.
+ std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(),
+ [](const LogicalBuffer* x, const LogicalBuffer* y) {
+ return x->id() < y->id();
+ });
// Allocate buffers defined by this instruction. This is the latest point
// that we can allocate; right before the buffer is first used. This must
@@ -203,6 +205,8 @@ Status HeapSimulator::RunComputation(
CanShareOperandBufferWithUser(
operand_buffer->instruction(), operand_buffer->index(),
buffer->instruction(), buffer->index(), points_to_analysis)) {
+ VLOG(3) << " Sharing: " << buffer->ToString() << " with "
+ << operand_buffer->ToString();
ShareBuffer(buffer, operand_buffer, instruction);
shared = true;
break;
@@ -211,6 +215,7 @@ Status HeapSimulator::RunComputation(
}
if (!shared) {
+ VLOG(3) << " Allocating: " << buffer->ToString();
Alloc(buffer, instruction);
}
}
@@ -244,20 +249,34 @@ Status HeapSimulator::RunComputation(
// Free buffers that are no longer live. This is the earliest point that we
// can de-allocate; right after the last use of the buffer.
for (const LogicalBuffer* buffer : dead_buffers_to_free) {
+ VLOG(3) << " Freeing dead: " << buffer->ToString();
Free(buffer, instruction);
}
for (const LogicalBuffer* buffer : operand_buffers_to_free) {
+ VLOG(3) << " Freeing operand: " << buffer->ToString();
Free(buffer, instruction);
}
}
// Any remaining live buffers must be entry parameters or output source
- // buffers, which had a nullptr sentry added. Free them now.
+ // buffers, which had a nullptr sentry added. Free them now, in a
+ // deterministic order.
+ std::vector<const LogicalBuffer*> to_free;
+ to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
const LogicalBuffer* buffer = buffer_pending.first;
const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
+ to_free.push_back(buffer);
+ }
+
+ std::sort(to_free.begin(), to_free.end(),
+ [](const LogicalBuffer* x, const LogicalBuffer* y) {
+ return x->id() < y->id();
+ });
+ for (const LogicalBuffer* buffer : to_free) {
+ VLOG(3) << "Freeing pending: " << buffer->ToString();
Free(buffer, root);
}
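
For the freeing side, a rough self-contained sketch of the per-instruction bookkeeping the new code performs (Buffer, Instr and BuffersFreedBy are hypothetical simplifications; the IgnoreBuffer check and the nullptr sentry handling from the real code are omitted): the visited instruction is erased from the live set of each buffer it used, found via the used_buffers reverse map, and buffers whose live set becomes empty are returned in increasing id order so the Free calls are deterministic.

#include <algorithm>
#include <map>
#include <set>
#include <vector>

struct Buffer { int id; };
struct Instr { int id; };

// Erase `instruction` from the live set of every buffer it used; buffers whose
// live set becomes empty are collected and sorted by id, mirroring the new
// std::sort on operand_buffers_to_free in the diff above.
std::vector<const Buffer*> BuffersFreedBy(
    const Instr* instruction,
    std::map<const Buffer*, std::set<const Instr*>>& live_buffers,
    std::map<const Instr*, std::set<const Buffer*>>& used_buffers) {
  std::vector<const Buffer*> to_free;
  for (const Buffer* buffer : used_buffers[instruction]) {
    std::set<const Instr*>& live_set = live_buffers[buffer];
    live_set.erase(instruction);
    if (live_set.empty()) to_free.push_back(buffer);
  }
  std::sort(to_free.begin(), to_free.end(),
            [](const Buffer* x, const Buffer* y) { return x->id < y->id; });
  return to_free;
}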