diff options
author | Yunxing Dai <yunxing@google.com> | 2018-10-08 21:18:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 21:23:03 -0700 |
commit | 375c109659d2d0e6265447dffdeb460693b3cccf (patch) | |
tree | a6f09b6472cff1ade7fc91c1ff0d5e3f473da774 | |
parent | d58712b7fc8de0e1f87fe2ea5221bc3c85230ed3 (diff) |
[XLA] Introduce input/output alias config.
- This CL intruduces input/output alias config in HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased.
- Copy insertion / buffer assignment and alias analysis has been updated to correctly honor the config and avoid any possible liveness interference.
PiperOrigin-RevId: 216299501
16 files changed, 1037 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2b292ed053..26ebb88e96 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -294,6 +294,7 @@ cc_library( srcs = [ "dfs_hlo_visitor.cc", "hlo_computation.cc", + "hlo_input_output_alias_config.cc", "hlo_instruction.cc", "hlo_instructions.cc", "hlo_module.cc", @@ -308,6 +309,7 @@ cc_library( "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", + "hlo_input_output_alias_config.h", "hlo_instruction.h", "hlo_instructions.h", "hlo_module.h", @@ -1268,6 +1270,25 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_input_output_alias_config_test", + srcs = ["hlo_input_output_alias_config_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_memory_scheduler", srcs = ["hlo_memory_scheduler.cc"], diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 2c2d1626c2..d5d6a044a8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to " << this; + VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } } - if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = - assignment->module_->entry_computation(); - for (auto param : entry_computation->parameter_instructions()) { - for (auto& param_buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - param)) { - if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { - VLOG(4) << "Can't assign: Parameter interference with result"; - return false; - } - } - } - } - // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). @@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets( // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and -// kConditional). +// kConditional and input output aliasing). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector<ColocatedBufferSet>* colocated_buffer_sets) { const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); + + // Set up colocated buffer set for input and output. + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + std::vector<const LogicalBuffer*> colocated_set; + AddBufferToColocatedSet(module->entry_computation()->root_instruction(), + output_index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet( + module->entry_computation()->parameter_instruction(param_number), + param_index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index 69b3646356..11d8abc5ba 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -141,6 +141,9 @@ class BufferValue { // operator< is required for std::set. bool operator<(const BufferValue& other) const { return id_ < other.id_; } + bool operator==(const BufferValue& other) const { return id_ == other.id_; } + bool operator!=(const BufferValue& other) const { return id_ != other.id_; } + virtual string ToString() const = 0; // TODO(lauj) rename LogicalBufferProto to BufferValueProto. diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index f35324aa35..cfe025fdd1 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -40,10 +40,12 @@ namespace { using absl::StrAppend; -bool IsEntryParameterValue(const HloValue& value) { +bool IsReadonlyEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); + computation == computation->parent()->entry_computation() && + !computation->parent()->input_output_alias_config().ParameterHasAlias( + value.defining_instruction()->parameter_number()); } bool IsConstantValue(const HloValue& value) { @@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) { } bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); + return IsConstantValue(value) || IsReadonlyEntryParameterValue(value); } // Data structure describing the action which should be taken on parts of a @@ -332,6 +334,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Conservatively adds copies before root instruction of entry computation and +// each aliased parameter to resolve interference of aliased input and output +// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +Status AddCopiesForAliasedInputOutputs(HloModule* module) { + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + + ShapeTree<bool> output_indices_to_copy(root->shape()); + std::vector<ShapeTree<HloInstruction*>> copied_parameters; + bool has_alias = false; + for (auto* param : entry->parameter_instructions()) { + bool param_has_alias = false; + ShapeTree<bool> param_indices_to_copy(param->shape()); + + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + if (param_number == param->parameter_number()) { + param_has_alias = true; + *(param_indices_to_copy.mutable_element(param_index)) = true; + *(output_indices_to_copy.mutable_element(output_index)) = true; + } + }); + + if (!param_has_alias) { + continue; + } + + has_alias = true; + // Store a snapshot of users before DeepCopyInstruction, as + // DeepCopyInstruction introduces new users of the instruction. + std::vector<HloInstruction*> users = param->users(); + ShapeTree<HloInstruction*> param_copy_tree(param->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * copied, + entry->DeepCopyInstruction( + param, ¶m_indices_to_copy, ¶m_copy_tree)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); + } + + copied_parameters.push_back(param_copy_tree); + } + + if (!has_alias) { + return Status::OK(); + } + + // Add copies before root instruction. + ShapeTree<HloInstruction*> output_copy_tree(root->shape(), + /*init_value=*/nullptr); + + TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, + root->parent()->DeepCopyInstruction( + root, &output_indices_to_copy, &output_copy_tree)); + + // Add control dependencies between the input/output copies. + TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& input_index) -> Status { + HloInstruction* from = + copied_parameters[param_number].element(input_index); + HloInstruction* to = output_copy_tree.element(output_index); + + TF_RET_CHECK(from != nullptr); + TF_RET_CHECK(to != nullptr); + TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); + return Status::OK(); + })); + + entry->set_root_instruction(root_copied); + + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -953,6 +1030,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } } } + + TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 892d0d7b54..3096206c34 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1351,6 +1351,189 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); } +TEST_F(CopyInsertionTest, CrossingParameters) { + // Test a case where two parameters' dataflow cross with each other while + // input and output are aliased with same index: + // + // (p0 , p1) + // | \ /| + // | \ / | + // alias X alias + // | / \ | + // | / \| + // (p1 , p0) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 4); +} + +TEST_F(CopyInsertionTest, ParametersAliasing) { + // Test a case where two parameters' dataflow don't interfere with each other + // while aliased. + // + // (p0 , p1) + // | | + // | | + // alias alias + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(param, 0)), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. + // + // (p0 , p1) + // | | + // | | + // alias | + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(param, 0)), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. + // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. + // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | Add----+ + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, negate0, negate1)); + builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // Test a while instruction with a body which permutes its tuple parameter // elements and applies one operation to one of the elements. The addition of diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index a0eb9e6ddc..82c8fb1904 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -225,6 +225,32 @@ message HloScheduleProto { map<int64, InstructionSequence> sequences = 1; } +message HloInputOutputAliasProto { + // The following proto describes a pair of aliased an input + // (described by parameter number and a ShapeIndex of the parameter) + // and an output (described by a ShapeIndex of the root + // instruction). For example: + // + // entry = { + // output_shape_index={1}, + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first paremter's {1, 2} element is + // aliased with the {1} element of the root instruction. + message AliasEntryProto { + // ShapeIndex of the root hlo. + repeated int64 output_shape_index = 1; + // Number of the parameter in entry computation. + int64 parameter_number = 2; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 3; + } + + repeated AliasEntryProto entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -243,6 +269,9 @@ message HloModuleProto { // The schedule for this module. HloScheduleProto schedule = 7; + + // Describes alias information between inputs and outputs. + HloInputOutputAliasProto input_output_alias = 8; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index c3da12e273..cf8e6594cb 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -59,8 +59,9 @@ class BufferValueMap { // construction process. using BufferNumber = int64; - explicit BufferValueMap(const HloDataflowAnalysis& dataflow) - : dataflow_(dataflow) { + explicit BufferValueMap(HloModule* module, + const HloDataflowAnalysis& dataflow) + : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); value_to_buffer_number_.reserve(dataflow_.values().size()); for (const HloValue* value : dataflow_.values()) { @@ -171,6 +172,42 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } + void ComputeInputOutputAliasedBuffers( + const HloValue& value, std::vector<BufferNumber>* aliased_buffers) { + // Get parameter value from an aliased_input object. + const auto get_parameter_value = + [this](const std::pair<int64, ShapeIndex>& aliased_input) + -> const HloValue& { + int64 param_number = aliased_input.first; + const ShapeIndex& param_index = aliased_input.second; + return dataflow_.GetUniqueValueAt( + module_->entry_computation()->parameter_instruction(param_number), + param_index); + }; + + // If the value shows up in a root instruction, alias it with parameter + // intruction. + for (const HloPosition& pos : value.positions()) { + if (pos.instruction == module_->entry_computation()->root_instruction()) { + ShapeIndex output_index = pos.index; + + auto aliased_input = + module_->input_output_alias_config().GetAliasedParameter( + output_index); + if (aliased_input) { + aliased_buffers->push_back( + GetBufferForValue(get_parameter_value(*aliased_input))); + } + } + } + + // If the value is parameter instruction itself, alias it with itself. + if (value.instruction()->opcode() == HloOpcode::kParameter && + value.instruction()->parent() == module_->entry_computation()) { + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + void ComputeWhileAliasedBuffers(const HloValue& value, std::vector<BufferNumber>* aliased_buffers) { VLOG(3) << "Compute kWhile aliases"; @@ -278,6 +315,7 @@ class BufferValueMap { VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; } std::vector<BufferNumber> aliased_buffers; + ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. @@ -288,6 +326,8 @@ class BufferValueMap { return aliased_buffers; } + HloModule* module_; + // Dataflow analysis used to construct the buffer map. const HloDataflowAnalysis& dataflow_; @@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run( /*bitcast_defines_value=*/false, fusion_can_share_buffer)); - BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); // Create a vector of HloBuffers, one for each set of values in the diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0cd0ab36fc..5c8d97b2d1 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } +TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { + // parameter 0 aliased with output 1 and parameter 1 aliased with output 0. + // + // (p0 , p1) + // \ / + // \ / + // alias X + // / \ + // / \ + // (p0 , p1) + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Every Ops in this graph are aliased with each other. + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { + // Test a simple single while instruction can be aliased with input and output + // of the computation. + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %param1 = param1 + // %while = While(%param1, body, condition) + // %while_1 = GTE(%while, 0) + // %while_2 = GTE(%while, 1) + // %negate_1 = Negate(%while_1) + // %negate_2 = Negate(%while_2) + // return Tuple(negate_1, negate_2) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, param)); + auto while_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0)); + auto while_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1)); + auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_1)); + auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(param, {1}), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add), + GetValueDefinedAt(negate_2))); + + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + UnorderedElementsAre( + HloPosition{param, {1}}, HloPosition{xla_while, {1}}, + HloPosition{while_element_2, {}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}}, + HloPosition{cond_param, {1}}, HloPosition{negate_2, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); +} + TEST_F(HloAliasAnalysisTest, SingleCall) { // Test a single call of a subcomputation. The subcomputation adds its two // array-shaped parameters. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index c22adcdd8d..f401eac016 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const HloValue& HloDataflowAnalysis::GetValueDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const { - CHECK(ValueIsDefinedAt(instruction, index)); + CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString(); return GetUniqueValueAt(instruction, index); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc new file mode 100644 index 0000000000..9ad98e5038 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -0,0 +1,172 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) { + // Output can't be aliased with multiple parameters. + TF_RET_CHECK(!alias_.element(output_index)); + (*alias_.mutable_element(output_index)) = + std::make_pair(param_number, param_index); + return Status::OK(); +} + +HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { + HloInputOutputAliasProto result; + alias_.ForEachElement( + [&](const ShapeIndex& index, + const absl::optional<std::pair<int64, ShapeIndex>>& data) { + if (data) { + HloInputOutputAliasProto::AliasEntryProto entry; + for (int64 i : index) { + entry.add_output_shape_index(i); + } + entry.set_parameter_number(data->first); + for (int64 i : data->second) { + entry.add_parameter_shape_index(i); + } + result.add_entries()->Swap(&entry); + } + }); + return result; +} + +StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto( + const HloModule* module, const HloInputOutputAliasProto& proto) { + HloInputOutputAliasConfig result( + module->entry_computation()->root_instruction()->shape()); + for (const HloInputOutputAliasProto::AliasEntryProto& entry : + proto.entries()) { + ShapeIndex output_index(entry.output_shape_index().begin(), + entry.output_shape_index().end()); + + int64 param_number = entry.parameter_number(); + ShapeIndex param_index(entry.parameter_shape_index().begin(), + entry.parameter_shape_index().end()); + TF_RETURN_IF_ERROR( + result.SetUpAlias(output_index, param_number, param_index)); + } + + return result; +} + +string HloInputOutputAliasConfig::ToString() const { + std::vector<string> pieces; + pieces.push_back("HloInputOutputAliasConfig"); + + ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + pieces.push_back(absl::StrFormat( + " OutputIndex %s is aliased with parameter %lld at %s:", + output_index.ToString(), param_number, param_index.ToString())); + }); + + return absl::StrJoin(pieces, "\n"); +} + +bool HloInputOutputAliasConfig::ParameterHasAlias(int64 param_number) const { + bool output = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, + absl::optional<std::pair<int64, ShapeIndex>> alias) { + if (alias && alias->first == param_number) { + output = true; + } + }); + return output; +} + +absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const { + absl::optional<ShapeIndex> output; + alias_.ForEachElement( + [&](const xla::ShapeIndex& output_index, + absl::optional<std::pair<int64, ShapeIndex>> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = output_index; + } + }); + return output; +} + +absl::optional<std::pair<int64, ShapeIndex>> +HloInputOutputAliasConfig::GetAliasedParameter( + const ShapeIndex& output_index) const { + CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); + return alias_.element(output_index); +} + +void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { + alias_.ForEachElement( + [&](const ShapeIndex& output_index, + absl::optional<std::pair<int64, ShapeIndex>> aliased) { + if (aliased) { + fn(output_index, aliased->first, aliased->second); + } + }); +} + +Status HloInputOutputAliasConfig::ForEachAliasWithStatus( + AliasFnWithStatus fn) const { + return alias_.ForEachElementWithStatus( + [&](const ShapeIndex& output_index, + absl::optional<std::pair<int64, ShapeIndex>> aliased) { + if (aliased) { + TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + } + return Status::OK(); + }); +} + +Status HloInputOutputAliasConfig::Verify(const HloModule& module) const { + std::vector<ShapeTree<bool>> param_has_seen; + const HloComputation* entry = module.entry_computation(); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + HloInstruction* param = entry->parameter_instruction(i); + param_has_seen.emplace_back(param->shape()); + } + return ForEachAliasWithStatus([&](const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) -> Status { + const HloInstruction* root = entry->root_instruction(); + + const Shape& param_shape = + entry->parameter_instruction(param_number)->shape(); + const Shape& output_shape = root->shape(); + TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); + + // Check each param_number and param_index pair only show up once. No + // input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); + + *(param_has_seen[param_number].mutable_element(param_index)) = true; + + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config) { + out << config.ToString(); + return out; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h new file mode 100644 index 0000000000..02c46f65c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ + +#include <utility> + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; + +// This class specifies the alias map from output index to parameter number and +// parameter index in the entry computation. +class HloInputOutputAliasConfig { + public: + HloInputOutputAliasConfig() = default; + + explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + + virtual ~HloInputOutputAliasConfig() = default; + + // Sets up alias config from `output_index` to `param_index` at + // `param_number`. + Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index); + + // Returns true if the given parameter is aliased with one of the output + // buffers. + bool ParameterHasAlias(int64 param_number) const; + + // (De)Serializes an HloInputOutoutAliasConfig to/from an + // HloInputOutoutAliasProto. + HloInputOutputAliasProto ToProto() const; + + static StatusOr<HloInputOutputAliasConfig> CreateFromProto( + const HloModule* module, const HloInputOutputAliasProto& proto); + + // Returns the output index that the given parameter and parameter index is + // aliased with. A nullopt is returned if there is no output that is aliased + // with the parameter number and index. + absl::optional<ShapeIndex> GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const; + + // Returns the number of parameter and index of the parameter buffer that the + // given output buffer index is aliased with. A nullopt is returned if there + // is no parameter is aliased with the specific output. + absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter( + const ShapeIndex& output_index) const; + + using AliasFn = + std::function<void(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index)>; + + // Iterates through each aliased output and input. + void ForEachAlias(AliasFn fn) const; + + using AliasFnWithStatus = + std::function<Status(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index)>; + + // Verifies that the given config is valid for the given module. + // Specifically, the config's input and output should be in-bound and size of + // the aliased buffers should match. + Status Verify(const HloModule& module) const; + + Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + + string ToString() const; + + private: + // A ShapeTree which indicates the list of buffers that's expected to be + // aliased. The key on this shape tree represents the output index. The value + // is a pair of parameter number and index into the buffer. If the value is + // nullopt, it means there is no parameter aliasing for this output. + ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_; +}; + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc new file mode 100644 index 0000000000..3b61ff04e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" + +#include <memory> +#include <string> + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class HloInputOutputAliasConfigTest : public HloTestBase { + protected: + void expect_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional<ShapeIndex> aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_TRUE(aliased_output); + EXPECT_EQ(aliased_output.value(), output_index); + + absl::optional<std::pair<int64, ShapeIndex>> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_TRUE(aliased_param); + EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + } + + void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional<ShapeIndex> aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_FALSE(aliased_output && aliased_output == output_index); + + absl::optional<std::pair<int64, ShapeIndex>> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_FALSE(aliased_param && aliased_param->first == param_number && + aliased_param->second == param_index); + } +}; + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, config); + + expect_aliased(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.Verify(*module)); +} + +TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 93e04eb3db..547f74a0ed 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -73,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal( config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } + input_output_alias_config_ = HloInputOutputAliasConfig( + entry_computation_->root_instruction()->shape()); } if (uniquify_identifiers) { @@ -252,6 +254,9 @@ HloModuleProto HloModule::ToProto() const { if (has_schedule()) { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } + + *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); + return proto; } @@ -328,6 +333,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } TF_RET_CHECK(module->entry_computation_ != nullptr); + TF_ASSIGN_OR_RETURN(module->input_output_alias_config_, + HloInputOutputAliasConfig::CreateFromProto( + module.get(), proto.input_output_alias())); + // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. absl::flat_hash_set<string> computation_names; diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 735804e827..9b9dc3ba9f 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" @@ -212,6 +213,15 @@ class HloModule { return result; } + // input_output_alias_config indicates the list of aliased buffers that are + // expected from the module. + HloInputOutputAliasConfig& input_output_alias_config() { + return input_output_alias_config_; + } + const HloInputOutputAliasConfig& input_output_alias_config() const { + return input_output_alias_config_; + } + // Returns the number of unique intruction ids given out. All ids up to // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) int NumUniqueInstructionIds() const { return next_unique_id_; } @@ -284,6 +294,10 @@ class HloModule { // sequential order of instructions for each non-fusion computation in the // module. absl::optional<HloSchedule> schedule_; + + // alias_config indicates the alias information of input/output buffers that + // are expected from the module. + HloInputOutputAliasConfig input_output_alias_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index be3bee5975..2902a11a42 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1220,6 +1220,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(module->schedule().Verify()); } + TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 73f541d505..51cedce7f0 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -72,7 +72,7 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } - // push_front is O(n^2), but shapes don't usually have a ton of dimensions. + // push_front is O(n), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } using container_type = absl::InlinedVector<int64, 2>; |