diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_alias_analysis.cc | 46 |
1 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index c3da12e273..cf8e6594cb 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -59,8 +59,9 @@ class BufferValueMap { // construction process. using BufferNumber = int64; - explicit BufferValueMap(const HloDataflowAnalysis& dataflow) - : dataflow_(dataflow) { + explicit BufferValueMap(HloModule* module, + const HloDataflowAnalysis& dataflow) + : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); value_to_buffer_number_.reserve(dataflow_.values().size()); for (const HloValue* value : dataflow_.values()) { @@ -171,6 +172,42 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } + void ComputeInputOutputAliasedBuffers( + const HloValue& value, std::vector<BufferNumber>* aliased_buffers) { + // Get parameter value from an aliased_input object. + const auto get_parameter_value = + [this](const std::pair<int64, ShapeIndex>& aliased_input) + -> const HloValue& { + int64 param_number = aliased_input.first; + const ShapeIndex& param_index = aliased_input.second; + return dataflow_.GetUniqueValueAt( + module_->entry_computation()->parameter_instruction(param_number), + param_index); + }; + + // If the value shows up in a root instruction, alias it with parameter + // intruction. + for (const HloPosition& pos : value.positions()) { + if (pos.instruction == module_->entry_computation()->root_instruction()) { + ShapeIndex output_index = pos.index; + + auto aliased_input = + module_->input_output_alias_config().GetAliasedParameter( + output_index); + if (aliased_input) { + aliased_buffers->push_back( + GetBufferForValue(get_parameter_value(*aliased_input))); + } + } + } + + // If the value is parameter instruction itself, alias it with itself. + if (value.instruction()->opcode() == HloOpcode::kParameter && + value.instruction()->parent() == module_->entry_computation()) { + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + void ComputeWhileAliasedBuffers(const HloValue& value, std::vector<BufferNumber>* aliased_buffers) { VLOG(3) << "Compute kWhile aliases"; @@ -278,6 +315,7 @@ class BufferValueMap { VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; } std::vector<BufferNumber> aliased_buffers; + ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. @@ -288,6 +326,8 @@ class BufferValueMap { return aliased_buffers; } + HloModule* module_; + // Dataflow analysis used to construct the buffer map. const HloDataflowAnalysis& dataflow_; @@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run( /*bitcast_defines_value=*/false, fusion_can_share_buffer)); - BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); // Create a vector of HloBuffers, one for each set of values in the |