diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 2c2d1626c2..d5d6a044a8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to " << this; + VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } } - if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = - assignment->module_->entry_computation(); - for (auto param : entry_computation->parameter_instructions()) { - for (auto& param_buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - param)) { - if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { - VLOG(4) << "Can't assign: Parameter interference with result"; - return false; - } - } - } - } - // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). @@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets( // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and -// kConditional). +// kConditional and input output aliasing). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector<ColocatedBufferSet>* colocated_buffer_sets) { const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); + + // Set up colocated buffer set for input and output. + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + std::vector<const LogicalBuffer*> colocated_set; + AddBufferToColocatedSet(module->entry_computation()->root_instruction(), + output_index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet( + module->entry_computation()->parameter_instruction(param_number), + param_index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; |