about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/buffer_assignment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment.cc')
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment.cc | 21
1 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 3c5b360c8e..b422b22df9 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -497,19 +497,19 @@ Status GatherComputationsByAllocationType(
std::vector<const HloComputation*>* global_computations) {
// Create a worklist of computations paired with whether the allocation must
// be thread-local.
- std::deque<std::pair<const HloComputation*, bool>> worklist;
+ std::deque<std::pair<HloComputation*, bool>> worklist;
worklist.push_back(std::make_pair(module->entry_computation(),
/*is_thread_local*/ false));
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- FlatSet<const HloComputation*> thread_local_set;
- FlatSet<const HloComputation*> global_set;
+ FlatSet<HloComputation*> thread_local_set;
+ FlatSet<HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
worklist.pop_front();
- const HloComputation* computation = worklist_front.first;
+ HloComputation* computation = worklist_front.first;
bool is_thread_local = worklist_front.second;
bool in_thread_local_set = thread_local_set.count(computation) > 0;
bool in_global_set = global_set.count(computation) > 0;
@@ -653,7 +653,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
}
if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
- const HloComputation* entry_computation =
+ HloComputation* entry_computation =
assignment->module_->entry_computation();
for (auto param : entry_computation->parameter_instructions()) {
for (auto& param_buffer :
@@ -819,6 +819,17 @@ Status BufferAssigner::AssignBuffersForComputation(
continue;
}
+ if (instruction->opcode() == HloOpcode::kRecv) {
+ // Make sure that recv operations get a new unique allocation so that
+ // they don't share their buffer with any other operations.
+ BufferAllocation* allocation = assignment->NewAllocation(
+ *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
+ allocation_indices.push_back(allocation->index());
+ VLOG(3) << "New allocation #" << allocation->index()
+ << " for recv: " << *buffer;
+ continue;
+ }
+
if (ShapeUtil::IsTuple(buffer->shape())) {
// TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
// assumes longer buffer liveness than indicated by the analysis.