aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc46
1 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index c3da12e273..cf8e6594cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -59,8 +59,9 @@ class BufferValueMap {
// construction process.
using BufferNumber = int64;
- explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
- : dataflow_(dataflow) {
+ explicit BufferValueMap(HloModule* module,
+ const HloDataflowAnalysis& dataflow)
+ : module_(module), dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
@@ -171,6 +172,42 @@ class BufferValueMap {
return value_to_buffer_number_.at(&value);
}
+ void ComputeInputOutputAliasedBuffers(
+ const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+ // Get parameter value from an aliased_input object.
+ const auto get_parameter_value =
+ [this](const std::pair<int64, ShapeIndex>& aliased_input)
+ -> const HloValue& {
+ int64 param_number = aliased_input.first;
+ const ShapeIndex& param_index = aliased_input.second;
+ return dataflow_.GetUniqueValueAt(
+ module_->entry_computation()->parameter_instruction(param_number),
+ param_index);
+ };
+
+ // If the value shows up in a root instruction, alias it with parameter
+ // intruction.
+ for (const HloPosition& pos : value.positions()) {
+ if (pos.instruction == module_->entry_computation()->root_instruction()) {
+ ShapeIndex output_index = pos.index;
+
+ auto aliased_input =
+ module_->input_output_alias_config().GetAliasedParameter(
+ output_index);
+ if (aliased_input) {
+ aliased_buffers->push_back(
+ GetBufferForValue(get_parameter_value(*aliased_input)));
+ }
+ }
+ }
+
+ // If the value is parameter instruction itself, alias it with itself.
+ if (value.instruction()->opcode() == HloOpcode::kParameter &&
+ value.instruction()->parent() == module_->entry_computation()) {
+ aliased_buffers->push_back(GetBufferForValue(value));
+ }
+ }
+
void ComputeWhileAliasedBuffers(const HloValue& value,
std::vector<BufferNumber>* aliased_buffers) {
VLOG(3) << "Compute kWhile aliases";
@@ -278,6 +315,7 @@ class BufferValueMap {
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
}
std::vector<BufferNumber> aliased_buffers;
+ ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
@@ -288,6 +326,8 @@ class BufferValueMap {
return aliased_buffers;
}
+ HloModule* module_;
+
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
@@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
/*bitcast_defines_value=*/false,
fusion_can_share_buffer));
- BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
+ BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
// Create a vector of HloBuffers, one for each set of values in the