aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-10-08 21:18:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 21:23:03 -0700
commit375c109659d2d0e6265447dffdeb460693b3cccf (patch)
treea6f09b6472cff1ade7fc91c1ff0d5e3f473da774 /tensorflow/compiler/xla/service/hlo_alias_analysis.cc
parentd58712b7fc8de0e1f87fe2ea5221bc3c85230ed3 (diff)
[XLA] Introduce input/output alias config.
- This CL intruduces input/output alias config in HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased. - Copy insertion / buffer assignment and alias analysis has been updated to correctly honor the config and avoid any possible liveness interference. PiperOrigin-RevId: 216299501
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc46
1 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index c3da12e273..cf8e6594cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -59,8 +59,9 @@ class BufferValueMap {
// construction process.
using BufferNumber = int64;
- explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
- : dataflow_(dataflow) {
+ explicit BufferValueMap(HloModule* module,
+ const HloDataflowAnalysis& dataflow)
+ : module_(module), dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
@@ -171,6 +172,42 @@ class BufferValueMap {
return value_to_buffer_number_.at(&value);
}
+ void ComputeInputOutputAliasedBuffers(
+ const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+ // Get parameter value from an aliased_input object.
+ const auto get_parameter_value =
+ [this](const std::pair<int64, ShapeIndex>& aliased_input)
+ -> const HloValue& {
+ int64 param_number = aliased_input.first;
+ const ShapeIndex& param_index = aliased_input.second;
+ return dataflow_.GetUniqueValueAt(
+ module_->entry_computation()->parameter_instruction(param_number),
+ param_index);
+ };
+
+ // If the value shows up in a root instruction, alias it with parameter
+ // intruction.
+ for (const HloPosition& pos : value.positions()) {
+ if (pos.instruction == module_->entry_computation()->root_instruction()) {
+ ShapeIndex output_index = pos.index;
+
+ auto aliased_input =
+ module_->input_output_alias_config().GetAliasedParameter(
+ output_index);
+ if (aliased_input) {
+ aliased_buffers->push_back(
+ GetBufferForValue(get_parameter_value(*aliased_input)));
+ }
+ }
+ }
+
+ // If the value is parameter instruction itself, alias it with itself.
+ if (value.instruction()->opcode() == HloOpcode::kParameter &&
+ value.instruction()->parent() == module_->entry_computation()) {
+ aliased_buffers->push_back(GetBufferForValue(value));
+ }
+ }
+
void ComputeWhileAliasedBuffers(const HloValue& value,
std::vector<BufferNumber>* aliased_buffers) {
VLOG(3) << "Compute kWhile aliases";
@@ -278,6 +315,7 @@ class BufferValueMap {
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
}
std::vector<BufferNumber> aliased_buffers;
+ ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
@@ -288,6 +326,8 @@ class BufferValueMap {
return aliased_buffers;
}
+ HloModule* module_;
+
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
@@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
/*bitcast_defines_value=*/false,
fusion_can_share_buffer));
- BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
+ BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
// Create a vector of HloBuffers, one for each set of values in the