diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 21 |
1 file changed, 16 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3c5b360c8e..b422b22df9 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -497,19 +497,19 @@ Status GatherComputationsByAllocationType( std::vector<const HloComputation*>* global_computations) { // Create a worklist of computations paired with whether the allocation must // be thread-local. - std::deque<std::pair<const HloComputation*, bool>> worklist; + std::deque<std::pair<HloComputation*, bool>> worklist; worklist.push_back(std::make_pair(module->entry_computation(), /*is_thread_local*/ false)); // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet<const HloComputation*> thread_local_set; - FlatSet<const HloComputation*> global_set; + FlatSet<HloComputation*> thread_local_set; + FlatSet<HloComputation*> global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); worklist.pop_front(); - const HloComputation* computation = worklist_front.first; + HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; bool in_thread_local_set = thread_local_set.count(computation) > 0; bool in_global_set = global_set.count(computation) > 0; @@ -653,7 +653,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = + HloComputation* entry_computation = assignment->module_->entry_computation(); for (auto param : entry_computation->parameter_instructions()) { for (auto& param_buffer : @@ -819,6 +819,17 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } + if (instruction->opcode() == HloOpcode::kRecv) { + // Make sure that recv operations get a new unique allocation so that they + // don't share their buffer with any other 
operations. + BufferAllocation* allocation = assignment->NewAllocation( + *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + allocation_indices.push_back(allocation->index()); + VLOG(3) << "New allocation #" << allocation->index() + << " for recv: " << *buffer; + continue; + } + if (ShapeUtil::IsTuple(buffer->shape())) { // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // assumes longer buffer liveness than indicated by the analysis. |