aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-10-08 21:18:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 21:23:03 -0700
commit375c109659d2d0e6265447dffdeb460693b3cccf (patch)
treea6f09b6472cff1ade7fc91c1ff0d5e3f473da774
parentd58712b7fc8de0e1f87fe2ea5221bc3c85230ed3 (diff)
[XLA] Introduce input/output alias config.
- This CL introduces input/output alias config in HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased. - Copy insertion / buffer assignment and alias analysis have been updated to correctly honor the config and avoid any possible liveness interference. PiperOrigin-RevId: 216299501
-rw-r--r--tensorflow/compiler/xla/service/BUILD21
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc34
-rw-r--r--tensorflow/compiler/xla/service/buffer_value.h3
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc85
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc183
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto29
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc175
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc172
-rw-r--r--tensorflow/compiler/xla/service/hlo_input_output_alias_config.h101
-rw-r--r--tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc184
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc2
-rw-r--r--tensorflow/compiler/xla/shape_util.h2
16 files changed, 1037 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2b292ed053..26ebb88e96 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -294,6 +294,7 @@ cc_library(
srcs = [
"dfs_hlo_visitor.cc",
"hlo_computation.cc",
+ "hlo_input_output_alias_config.cc",
"hlo_instruction.cc",
"hlo_instructions.cc",
"hlo_module.cc",
@@ -308,6 +309,7 @@ cc_library(
"hlo_clone_context.h",
"hlo_computation.h",
"hlo_domain_metadata.h",
+ "hlo_input_output_alias_config.h",
"hlo_instruction.h",
"hlo_instructions.h",
"hlo_module.h",
@@ -1268,6 +1270,25 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "hlo_input_output_alias_config_test",
+ srcs = ["hlo_input_output_alias_config_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_dce",
+ ":hlo_memory_scheduler",
+ ":hlo_ordering",
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
+
cc_library(
name = "hlo_memory_scheduler",
srcs = ["hlo_memory_scheduler.cc"],
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 2c2d1626c2..d5d6a044a8 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
int64 size) {
- VLOG(4) << "Trying to add " << buffer << " to " << this;
+ VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
CHECK(assigned_buffers_.count(&buffer) == 0)
<< "LogicalBuffer " << buffer << " already assigned to allocation "
<< index_;
@@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
}
}
- if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
- const HloComputation* entry_computation =
- assignment->module_->entry_computation();
- for (auto param : entry_computation->parameter_instructions()) {
- for (auto& param_buffer :
- assignment->points_to_analysis().GetBuffersDefinedByInstruction(
- param)) {
- if (assignment->liveness().MayInterfere(*param_buffer, buffer)) {
- VLOG(4) << "Can't assign: Parameter interference with result";
- return false;
- }
- }
- }
- }
-
// If the buffer is live out of the computation then it should only be
// assigned a buffer which exactly fits the result to avoid wasting memory
// (result buffers can have arbitrary lifetimes).
@@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets(
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile, kCall, and
-// kConditional).
+// kConditional, and input/output aliasing).
void BufferAssigner::BuildColocatedBufferSets(
const HloModule* module, const BufferLiveness& buffer_liveness,
const LogicalBuffer::SizeFunction& buffer_size,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
const TuplePointsToAnalysis& points_to_analysis =
buffer_liveness.points_to_analysis();
+
+ // Set up colocated buffer set for input and output.
+ module->input_output_alias_config().ForEachAlias(
+ [&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index) {
+ std::vector<const LogicalBuffer*> colocated_set;
+ AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
+ output_index, points_to_analysis,
+ &colocated_set);
+ AddBufferToColocatedSet(
+ module->entry_computation()->parameter_instruction(param_number),
+ param_index, points_to_analysis, &colocated_set);
+ AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+ });
+
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
if (computation->IsFusionComputation()) {
continue;
diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h
index 69b3646356..11d8abc5ba 100644
--- a/tensorflow/compiler/xla/service/buffer_value.h
+++ b/tensorflow/compiler/xla/service/buffer_value.h
@@ -141,6 +141,9 @@ class BufferValue {
// operator< is required for std::set.
bool operator<(const BufferValue& other) const { return id_ < other.id_; }
+ bool operator==(const BufferValue& other) const { return id_ == other.id_; }
+ bool operator!=(const BufferValue& other) const { return id_ != other.id_; }
+
virtual string ToString() const = 0;
// TODO(lauj) rename LogicalBufferProto to BufferValueProto.
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index f35324aa35..cfe025fdd1 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -40,10 +40,12 @@ namespace {
using absl::StrAppend;
-bool IsEntryParameterValue(const HloValue& value) {
+bool IsReadonlyEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
- computation == computation->parent()->entry_computation();
+ computation == computation->parent()->entry_computation() &&
+ !computation->parent()->input_output_alias_config().ParameterHasAlias(
+ value.defining_instruction()->parameter_number());
}
bool IsConstantValue(const HloValue& value) {
@@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) {
}
bool ValueIsReadOnly(const HloValue& value) {
- return IsConstantValue(value) || IsEntryParameterValue(value);
+ return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}
// Data structure describing the action which should be taken on parts of a
@@ -332,6 +334,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
+// Conservatively adds copies before root instruction of entry computation and
+// each aliased parameter to resolve interference of aliased input and output
+// buffer. We later rely on the CopyRemover to drop the unnecessary ones.
+Status AddCopiesForAliasedInputOutputs(HloModule* module) {
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+
+ ShapeTree<bool> output_indices_to_copy(root->shape());
+ std::vector<ShapeTree<HloInstruction*>> copied_parameters;
+ bool has_alias = false;
+ for (auto* param : entry->parameter_instructions()) {
+ bool param_has_alias = false;
+ ShapeTree<bool> param_indices_to_copy(param->shape());
+
+ module->input_output_alias_config().ForEachAlias(
+ [&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index) {
+ if (param_number == param->parameter_number()) {
+ param_has_alias = true;
+ *(param_indices_to_copy.mutable_element(param_index)) = true;
+ *(output_indices_to_copy.mutable_element(output_index)) = true;
+ }
+ });
+
+ if (!param_has_alias) {
+ continue;
+ }
+
+ has_alias = true;
+ // Store a snapshot of users before DeepCopyInstruction, as
+ // DeepCopyInstruction introduces new users of the instruction.
+ std::vector<HloInstruction*> users = param->users();
+ ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
+ /*init_value=*/nullptr);
+ TF_ASSIGN_OR_RETURN(HloInstruction * copied,
+ entry->DeepCopyInstruction(
+ param, &param_indices_to_copy, &param_copy_tree));
+ for (HloInstruction* user : users) {
+ TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
+ }
+
+ copied_parameters.push_back(param_copy_tree);
+ }
+
+ if (!has_alias) {
+ return Status::OK();
+ }
+
+ // Add copies before root instruction.
+ ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
+ /*init_value=*/nullptr);
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
+ root->parent()->DeepCopyInstruction(
+ root, &output_indices_to_copy, &output_copy_tree));
+
+ // Add control dependencies between the input/output copies.
+ TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
+ [&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& input_index) -> Status {
+ HloInstruction* from =
+ copied_parameters[param_number].element(input_index);
+ HloInstruction* to = output_copy_tree.element(output_index);
+
+ TF_RET_CHECK(from != nullptr);
+ TF_RET_CHECK(to != nullptr);
+ TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
+ return Status::OK();
+ }));
+
+ entry->set_root_instruction(root_copied);
+
+ return Status::OK();
+}
+
// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
while (!instruction->control_successors().empty()) {
@@ -953,6 +1030,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
}
}
}
+
+ TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 892d0d7b54..3096206c34 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1351,6 +1351,189 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
+TEST_F(CopyInsertionTest, CrossingParameters) {
+ // Test a case where two parameters' dataflow cross with each other while
+ // input and output are aliased with same index:
+ //
+ // (p0 , p1)
+ // | \ /|
+ // | \ / |
+ // alias X alias
+ // | / \ |
+ // | / \|
+ // (p1 , p0)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 4);
+}
+
+TEST_F(CopyInsertionTest, ParametersAliasing) {
+ // Test a case where two parameters' dataflow don't interfere with each other
+ // while aliased.
+ //
+ // (p0 , p1)
+ // | |
+ // | |
+ // alias alias
+ // | |
+ // | |
+ // (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+ InsertCopies(module.get());
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
+ op::Copy(op::GetTupleElement(param, 1))));
+
+ EXPECT_EQ(CountCopies(*module), 2);
+}
+
+TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // (p0 , p1)
+ // | |
+ // | |
+ // alias |
+ // | |
+ // | |
+ // (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
+ op::Copy(op::GetTupleElement(param, 1))));
+
+ EXPECT_EQ(CountCopies(*module), 2);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // +-- (p0 , p1)
+ // | | |
+ // | | |
+ // alias Negate Negate
+ // | | |
+ // | | |
+ // +-- (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+ builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
+ // Test a case where one parameter is aliased with result while another one
+ // isn't.
+ //
+ // +-- (p0 , p1)
+ // | | |
+ // | | |
+ // alias Negate Negate
+ // | | |
+ // | Add----+
+ // | | |
+ // +-- (p0 , p1)
+ auto module = CreateNewModule();
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, negate0, negate1));
+ builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
+ module->AddEntryComputation(builder.Build());
+ ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ InsertCopies(module.get());
+
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
// Test a while instruction with a body which permutes its tuple parameter
// elements and applies one operation to one of the elements. The addition of
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index a0eb9e6ddc..82c8fb1904 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -225,6 +225,32 @@ message HloScheduleProto {
map<int64, InstructionSequence> sequences = 1;
}
+message HloInputOutputAliasProto {
+ // The following proto describes a pair of an aliased input
+ // (described by parameter number and a ShapeIndex of the parameter)
+ // and an output (described by a ShapeIndex of the root
+ // instruction). For example:
+ //
+ // entry = {
+ // output_shape_index={1},
+ // parameter_number=0,
+ // parameter_shape_index={1, 2},
+ // }
+ //
+ // This entry indicates that the first parameter's {1, 2} element is
+ // aliased with the {1} element of the root instruction.
+ message AliasEntryProto {
+ // ShapeIndex of the root hlo.
+ repeated int64 output_shape_index = 1;
+ // Number of the parameter in entry computation.
+ int64 parameter_number = 2;
+ // ShapeIndex of the parameter instruction.
+ repeated int64 parameter_shape_index = 3;
+ }
+
+ repeated AliasEntryProto entries = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -243,6 +269,9 @@ message HloModuleProto {
// The schedule for this module.
HloScheduleProto schedule = 7;
+
+ // Describes alias information between inputs and outputs.
+ HloInputOutputAliasProto input_output_alias = 8;
}
// Serialization of LogicalBuffer.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index c3da12e273..cf8e6594cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -59,8 +59,9 @@ class BufferValueMap {
// construction process.
using BufferNumber = int64;
- explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
- : dataflow_(dataflow) {
+ explicit BufferValueMap(HloModule* module,
+ const HloDataflowAnalysis& dataflow)
+ : module_(module), dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
@@ -171,6 +172,42 @@ class BufferValueMap {
return value_to_buffer_number_.at(&value);
}
+ void ComputeInputOutputAliasedBuffers(
+ const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+ // Get parameter value from an aliased_input object.
+ const auto get_parameter_value =
+ [this](const std::pair<int64, ShapeIndex>& aliased_input)
+ -> const HloValue& {
+ int64 param_number = aliased_input.first;
+ const ShapeIndex& param_index = aliased_input.second;
+ return dataflow_.GetUniqueValueAt(
+ module_->entry_computation()->parameter_instruction(param_number),
+ param_index);
+ };
+
+ // If the value shows up in the root instruction, alias it with the
+ // parameter instruction.
+ for (const HloPosition& pos : value.positions()) {
+ if (pos.instruction == module_->entry_computation()->root_instruction()) {
+ ShapeIndex output_index = pos.index;
+
+ auto aliased_input =
+ module_->input_output_alias_config().GetAliasedParameter(
+ output_index);
+ if (aliased_input) {
+ aliased_buffers->push_back(
+ GetBufferForValue(get_parameter_value(*aliased_input)));
+ }
+ }
+ }
+
+ // If the value is parameter instruction itself, alias it with itself.
+ if (value.instruction()->opcode() == HloOpcode::kParameter &&
+ value.instruction()->parent() == module_->entry_computation()) {
+ aliased_buffers->push_back(GetBufferForValue(value));
+ }
+ }
+
void ComputeWhileAliasedBuffers(const HloValue& value,
std::vector<BufferNumber>* aliased_buffers) {
VLOG(3) << "Compute kWhile aliases";
@@ -278,6 +315,7 @@ class BufferValueMap {
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
}
std::vector<BufferNumber> aliased_buffers;
+ ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
@@ -288,6 +326,8 @@ class BufferValueMap {
return aliased_buffers;
}
+ HloModule* module_;
+
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
@@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
/*bitcast_defines_value=*/false,
fusion_can_share_buffer));
- BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
+ BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
// Create a vector of HloBuffers, one for each set of values in the
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 0cd0ab36fc..5c8d97b2d1 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
+TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ // Cannot alias an output twice.
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
+ // parameter 0 aliased with output 1 and parameter 1 aliased with output 0.
+ //
+ // (p0 , p1)
+ // \ /
+ // \ /
+ // alias X
+ // / \
+ // / \
+ // (p0 , p1)
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+ // Cannot alias an output twice.
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ // All ops in this graph are aliased with each other.
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
+ // Test a simple single while instruction can be aliased with input and output
+ // of the computation.
+ //
+ // body((F32[], F32[]) %tuple_param):
+ // %add = Add(%tuple_param{0}, %tuple_param{1})
+ // return Tuple(%tuple_param{0}, %add)
+ //
+ // condition((F32[], F32[]) %tuple_param):
+ // return Constant(false)
+ //
+ // entry:
+ // %param1 = param1
+ // %while = While(%param1, body, condition)
+ // %while_1 = GTE(%while, 0)
+ // %while_2 = GTE(%while, 1)
+ // %negate_1 = Negate(%while_1)
+ // %negate_2 = Negate(%while_2)
+ // return Tuple(negate_1, negate_2)
+ //
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ // Element 0 passes transparently through the body.
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ auto body_element_0 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+ auto body_element_1 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+ auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
+ auto body_tuple = body_builder.AddInstruction(
+ HloInstruction::CreateTuple({body_element_0, add}));
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+ // Condition computation trivially returns a constant "false".
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ HloComputation* condition =
+ module_->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(tuple_shape, condition, body, param));
+ auto while_element_1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
+ auto while_element_2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
+ auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, while_element_1));
+ auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, while_element_2));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ EXPECT_THAT(
+ GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
+ UnorderedElementsAre(GetValueDefinedAt(param, {1}),
+ GetValueDefinedAt(xla_while, /*index=*/{1}),
+ GetValueDefinedAt(body_param, {1}),
+ GetValueDefinedAt(cond_param, {1}),
+ GetValueDefinedAt(add),
+ GetValueDefinedAt(negate_2)));
+
+ EXPECT_THAT(
+ analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
+ UnorderedElementsAre(
+ HloPosition{param, {1}}, HloPosition{xla_while, {1}},
+ HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
+ HloPosition{body_element_1, {}}, HloPosition{add, {}},
+ HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
+ HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));
+
+ EXPECT_FALSE(AnyValuesInSameBufferInterfere());
+}
+
TEST_F(HloAliasAnalysisTest, SingleCall) {
// Test a single call of a subcomputation. The subcomputation adds its two
// array-shaped parameters.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index c22adcdd8d..f401eac016 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
const HloInstruction* instruction, const ShapeIndex& index) const {
- CHECK(ValueIsDefinedAt(instruction, index));
+ CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
return GetUniqueValueAt(instruction, index);
}
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
new file mode 100644
index 0000000000..9ad98e5038
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
@@ -0,0 +1,172 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
+ int64 param_number,
+ const ShapeIndex& param_index) {
+ // Output can't be aliased with multiple parameters.
+ TF_RET_CHECK(!alias_.element(output_index));
+ (*alias_.mutable_element(output_index)) =
+ std::make_pair(param_number, param_index);
+ return Status::OK();
+}
+
+HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
+ HloInputOutputAliasProto result;
+ alias_.ForEachElement(
+ [&](const ShapeIndex& index,
+ const absl::optional<std::pair<int64, ShapeIndex>>& data) {
+ if (data) {
+ HloInputOutputAliasProto::AliasEntryProto entry;
+ for (int64 i : index) {
+ entry.add_output_shape_index(i);
+ }
+ entry.set_parameter_number(data->first);
+ for (int64 i : data->second) {
+ entry.add_parameter_shape_index(i);
+ }
+ result.add_entries()->Swap(&entry);
+ }
+ });
+ return result;
+}
+
+StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
+ const HloModule* module, const HloInputOutputAliasProto& proto) {
+ HloInputOutputAliasConfig result(
+ module->entry_computation()->root_instruction()->shape());
+ for (const HloInputOutputAliasProto::AliasEntryProto& entry :
+ proto.entries()) {
+ ShapeIndex output_index(entry.output_shape_index().begin(),
+ entry.output_shape_index().end());
+
+ int64 param_number = entry.parameter_number();
+ ShapeIndex param_index(entry.parameter_shape_index().begin(),
+ entry.parameter_shape_index().end());
+ TF_RETURN_IF_ERROR(
+ result.SetUpAlias(output_index, param_number, param_index));
+ }
+
+ return result;
+}
+
+string HloInputOutputAliasConfig::ToString() const {
+ std::vector<string> pieces;
+ pieces.push_back("HloInputOutputAliasConfig");
+
+ ForEachAlias([&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index) {
+ pieces.push_back(absl::StrFormat(
+ " OutputIndex %s is aliased with parameter %lld at %s:",
+ output_index.ToString(), param_number, param_index.ToString()));
+ });
+
+ return absl::StrJoin(pieces, "\n");
+}
+
+bool HloInputOutputAliasConfig::ParameterHasAlias(int64 param_number) const {
+ bool output = false;
+ alias_.ForEachElement(
+ [&](const xla::ShapeIndex&,
+ absl::optional<std::pair<int64, ShapeIndex>> alias) {
+ if (alias && alias->first == param_number) {
+ output = true;
+ }
+ });
+ return output;
+}
+
+absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
+ int64 param_number, const ShapeIndex& param_index) const {
+ absl::optional<ShapeIndex> output;
+ alias_.ForEachElement(
+ [&](const xla::ShapeIndex& output_index,
+ absl::optional<std::pair<int64, ShapeIndex>> alias) {
+ if (alias && alias->first == param_number &&
+ alias->second == param_index) {
+ output = output_index;
+ }
+ });
+ return output;
+}
+
+absl::optional<std::pair<int64, ShapeIndex>>
+HloInputOutputAliasConfig::GetAliasedParameter(
+ const ShapeIndex& output_index) const {
+ CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
+ return alias_.element(output_index);
+}
+
+void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
+ alias_.ForEachElement(
+ [&](const ShapeIndex& output_index,
+ absl::optional<std::pair<int64, ShapeIndex>> aliased) {
+ if (aliased) {
+ fn(output_index, aliased->first, aliased->second);
+ }
+ });
+}
+
+Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
+ AliasFnWithStatus fn) const {
+ return alias_.ForEachElementWithStatus(
+ [&](const ShapeIndex& output_index,
+ absl::optional<std::pair<int64, ShapeIndex>> aliased) {
+ if (aliased) {
+ TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second));
+ }
+ return Status::OK();
+ });
+}
+
+Status HloInputOutputAliasConfig::Verify(const HloModule& module) const {
+ std::vector<ShapeTree<bool>> param_has_seen;
+ const HloComputation* entry = module.entry_computation();
+ for (int64 i = 0; i < entry->num_parameters(); ++i) {
+ HloInstruction* param = entry->parameter_instruction(i);
+ param_has_seen.emplace_back(param->shape());
+ }
+ return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
+ int64 param_number,
+ const ShapeIndex& param_index) -> Status {
+ const HloInstruction* root = entry->root_instruction();
+
+ const Shape& param_shape =
+ entry->parameter_instruction(param_number)->shape();
+ const Shape& output_shape = root->shape();
+ TF_RET_CHECK(entry->num_parameters() > param_number);
+ TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index));
+ TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
+
+      // Check that each (param_number, param_index) pair shows up only once:
+      // no input buffer may be aliased with more than one output buffer.
+ TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false);
+
+ *(param_has_seen[param_number].mutable_element(param_index)) = true;
+
+ return Status::OK();
+ });
+}
+
+std::ostream& operator<<(std::ostream& out,
+ const HloInputOutputAliasConfig& config) {
+ out << config.ToString();
+ return out;
+}
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
new file mode 100644
index 0000000000..02c46f65c8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
+
+#include <utility>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+
+class HloModule;
+
+// This class specifies the alias map from output index to parameter number and
+// parameter index in the entry computation.
+class HloInputOutputAliasConfig {
+ public:
+ HloInputOutputAliasConfig() = default;
+
+ explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {}
+
+ virtual ~HloInputOutputAliasConfig() = default;
+
+ // Sets up alias config from `output_index` to `param_index` at
+ // `param_number`.
+ Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index);
+
+ // Returns true if the given parameter is aliased with one of the output
+ // buffers.
+ bool ParameterHasAlias(int64 param_number) const;
+
+  // (De)Serializes an HloInputOutputAliasConfig to/from an
+  // HloInputOutputAliasProto.
+ HloInputOutputAliasProto ToProto() const;
+
+ static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
+ const HloModule* module, const HloInputOutputAliasProto& proto);
+
+  // Returns the output index that the given parameter number and parameter
+  // index are aliased with. A nullopt is returned if there is no output that
+  // is aliased with the parameter number and index.
+ absl::optional<ShapeIndex> GetAliasedOutput(
+ int64 param_number, const ShapeIndex& param_index) const;
+
+  // Returns the parameter number and index of the parameter buffer that the
+  // given output buffer index is aliased with. A nullopt is returned if no
+  // parameter is aliased with the specific output.
+ absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter(
+ const ShapeIndex& output_index) const;
+
+ using AliasFn =
+ std::function<void(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index)>;
+
+ // Iterates through each aliased output and input.
+ void ForEachAlias(AliasFn fn) const;
+
+ using AliasFnWithStatus =
+ std::function<Status(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index)>;
+
+  // Verifies that the given config is valid for the given module.
+  // Specifically, the config's input and output indices should be in bounds,
+  // and the sizes of the aliased buffers should match.
+ Status Verify(const HloModule& module) const;
+
+ Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
+
+ string ToString() const;
+
+ private:
+  // A ShapeTree which indicates the list of buffers that are expected to be
+  // aliased. The key of this shape tree represents the output index. The
+  // value is a pair of parameter number and parameter index. If the value is
+  // nullopt, it means there is no parameter aliasing for this output.
+ ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_;
+};
+
+std::ostream& operator<<(std::ostream& out,
+ const HloInputOutputAliasConfig& config);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
new file mode 100644
index 0000000000..3b61ff04e6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+class HloInputOutputAliasConfigTest : public HloTestBase {
+ protected:
+ void expect_aliased(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index,
+ const HloInputOutputAliasConfig& config) {
+ absl::optional<ShapeIndex> aliased_output =
+ config.GetAliasedOutput(param_number, param_index);
+
+ EXPECT_TRUE(aliased_output);
+ EXPECT_EQ(aliased_output.value(), output_index);
+
+ absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+ config.GetAliasedParameter(output_index);
+
+ EXPECT_TRUE(aliased_param);
+ EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index));
+ }
+
+ void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index,
+ const HloInputOutputAliasConfig& config) {
+ absl::optional<ShapeIndex> aliased_output =
+ config.GetAliasedOutput(param_number, param_index);
+
+ EXPECT_FALSE(aliased_output && aliased_output == output_index);
+
+ absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+ config.GetAliasedParameter(output_index);
+
+ EXPECT_FALSE(aliased_param && aliased_param->first == param_number &&
+ aliased_param->second == param_index);
+ }
+};
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}));
+
+ expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ param = (f32[], f32[]) parameter(0)
+ gte1 = f32[] get-tuple-element(%param), index=0
+ gte2 = f32[] get-tuple-element(%param), index=1
+ ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{0}));
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{1}));
+
+ expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{0}, config);
+
+ expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{1}, config);
+
+ expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+ /*param_index=*/{}, config);
+
+ expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ ASSERT_IS_NOT_OK(config.Verify(*module));
+}
+
+TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
+ const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+
+ HloInputOutputAliasConfig config(
+ module->entry_computation()->root_instruction()->shape());
+
+ TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+ /*param_index=*/{}));
+
+ ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+ /*param_index=*/{}));
+}
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 93e04eb3db..547f74a0ed 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -73,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal(
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
+ input_output_alias_config_ = HloInputOutputAliasConfig(
+ entry_computation_->root_instruction()->shape());
}
if (uniquify_identifiers) {
@@ -252,6 +254,9 @@ HloModuleProto HloModule::ToProto() const {
if (has_schedule()) {
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
}
+
+ *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
+
return proto;
}
@@ -328,6 +333,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
+ TF_ASSIGN_OR_RETURN(module->input_output_alias_config_,
+ HloInputOutputAliasConfig::CreateFromProto(
+ module.get(), proto.input_output_alias()));
+
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
absl::flat_hash_set<string> computation_names;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 735804e827..9b9dc3ba9f 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -212,6 +213,15 @@ class HloModule {
return result;
}
+ // input_output_alias_config indicates the list of aliased buffers that are
+ // expected from the module.
+ HloInputOutputAliasConfig& input_output_alias_config() {
+ return input_output_alias_config_;
+ }
+ const HloInputOutputAliasConfig& input_output_alias_config() const {
+ return input_output_alias_config_;
+ }
+
// Returns the number of unique intruction ids given out. All ids up to
// this point are guaranteed to be in the range [0..NumUniqueInstructionIds())
int NumUniqueInstructionIds() const { return next_unique_id_; }
@@ -284,6 +294,10 @@ class HloModule {
// sequential order of instructions for each non-fusion computation in the
// module.
absl::optional<HloSchedule> schedule_;
+
+  // input_output_alias_config_ holds the alias information of input/output
+  // buffers that is expected from the module.
+ HloInputOutputAliasConfig input_output_alias_config_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index be3bee5975..2902a11a42 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1220,6 +1220,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(module->schedule().Verify());
}
+ TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module));
+
return false;
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 73f541d505..51cedce7f0 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -72,7 +72,7 @@ class ShapeIndex {
void push_back(int64 value) { indices_.push_back(value); }
void pop_back() { indices_.pop_back(); }
- // push_front is O(n^2), but shapes don't usually have a ton of dimensions.
+ // push_front is O(n), but shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
using container_type = absl::InlinedVector<int64, 2>;