diff options
17 files changed, 377 insertions, 282 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ae2df587dc..156cb85f66 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -125,6 +125,7 @@ cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index d623aa2caf..f1fc608caa 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -425,7 +425,8 @@ Status GatherComputationsByAllocationType( } for (auto& instruction : computation->instructions()) { - for (auto* subcomputation : instruction->MakeCalledComputationsSet()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kWhile: diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 9d3e024088..b5a2936b67 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -244,7 +244,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where // the singleton use of 'a' at 'a.index' is the fused root at operand 0. 
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { - if (alias.instruction()->users().count(b.instruction()) > 0 && + if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), b.instruction(), b.index(), points_to_analysis())) { diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 633016994a..4c26b2de12 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -90,8 +90,6 @@ class CopyInsertionTest : public HloTestBase { }; }; -#define EXPECT_INST(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E})) - TEST_F(CopyInsertionTest, SingleParameter) { auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( @@ -99,7 +97,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({x})); - EXPECT_INST(x->users(), tuple); + ExpectEqUnordered(x->users(), {tuple}); HloModule module(TestName()); module.AddEntryComputation(builder.Build()); @@ -127,7 +125,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); - EXPECT_INST(constant->users(), tuple); + ExpectEqUnordered(constant->users(), {tuple}); HloModule module(TestName()); module.AddEntryComputation(builder.Build()); @@ -221,9 +219,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - EXPECT_INST(constant1->users(), tuple1); - EXPECT_INST(constant2->users(), tuple1, tuple2); - EXPECT_INST(constant3->users(), tuple2); + ExpectEqUnordered(constant1->users(), {tuple1}); + ExpectEqUnordered(constant2->users(), {tuple1, tuple2}); + ExpectEqUnordered(constant3->users(), {tuple2}); HloModule 
module(TestName()); module.AddEntryComputation(builder.Build()); @@ -261,7 +259,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - EXPECT_INST(x->users(), bitcast); + ExpectEqUnordered(x->users(), {bitcast}); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); @@ -289,7 +287,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - EXPECT_INST(constant->users(), bitcast); + ExpectEqUnordered(constant->users(), {bitcast}); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); @@ -316,8 +314,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - EXPECT_EQ(1, x->user_count()); - EXPECT_EQ(*x->users().begin(), bitcast); + ExpectEqUnordered(x->users(), {bitcast}); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3931e909a0..0ee117afcd 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include <algorithm> +#include <vector> #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -249,7 +250,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { return Status::OK(); } // Merge fused instructions from 'fusion' into each user. 
- std::set<HloInstruction*> users = fusion->users(); + std::vector<HloInstruction*> users = fusion->users(); for (HloInstruction* user : users) { user->MergeFusionInstruction(fusion); changed_ = true; diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4d722970c5..76702f52e0 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -85,7 +85,8 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( } for (const BufferAlias& alias : points_to_analysis.GetBufferAliases(*buffer)) { - const std::set<HloInstruction*>& users = alias.instruction()->users(); + const std::vector<HloInstruction*>& users = + alias.instruction()->users(); if (!users.empty()) { live_buffers[buffer].insert(users.begin(), users.end()); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 366ecb5a52..c47b49a0c6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -142,6 +142,12 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { TF_RET_CHECK(instruction->user_count() == 0) << "instruction " << instruction->name() << " has users and cannot be removed"; + TF_RET_CHECK(instruction->control_predecessors().empty()) + << "instruction " << instruction->name() + << " has control predecessors and cannot be removed"; + TF_RET_CHECK(instruction->control_successors().empty()) + << "instruction " << instruction->name() + << " has control successors and cannot be removed"; TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); auto inst_it = instruction_iterators_.at(instruction); @@ -227,7 +233,8 @@ void ComputeComputationPostOrder( } for (auto& instruction : computation->instructions()) { - for (auto& called_computation : instruction->MakeCalledComputationsSet()) { + for (HloComputation* called_computation : + 
instruction->called_computations()) { ComputeComputationPostOrder(called_computation, visited, post_order); } } @@ -383,15 +390,6 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( } } -Status HloComputation::AddControlDependency(HloInstruction* predecessor, - HloInstruction* successor) { - TF_RET_CHECK(instruction_iterators_.count(predecessor) > 0); - TF_RET_CHECK(instruction_iterators_.count(successor) > 0); - successor->AddControlPredecessor(predecessor); - predecessor->AddControlSuccessor(successor); - return Status::OK(); -} - ProgramShape HloComputation::ComputeProgramShape() const { ProgramShape program_shape; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 1833f70551..fc02b2c4ef 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -128,17 +128,6 @@ class HloComputation { return instructions_; } - // Add a control dependency between the two instructions in this computation - // so that the 'predecessor' is visited before the 'successor' during the DFS - // traversal of the computation. Returns an error status if either of the - // given instructions does not belong to the current computation. - // - // This is used to enforce an additional ordering requirement that is not - // captured by normal data dependencies, such as ordering among Send or Recv - // operations to avoid deadlock. - Status AddControlDependency(HloInstruction* predecessor, - HloInstruction* successor); - // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. 
std::list<HloInstruction*> MakeInstructionPostOrder() const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 9b978ff490..12a5683396 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -297,7 +297,7 @@ TEST_F(HloComputationTest, CycleDetection) { auto computation = builder.Build(); // Add a control dependency to create a cycle. - ASSERT_IS_OK(computation->AddControlDependency(add, negate)); + ASSERT_IS_OK(add->AddControlDependencyTo(negate)); const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 4373180535..3f8a2f9859 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -193,8 +193,6 @@ string InstructionSequenceGraph( instruction->metadata().source_line()); } - std::vector<HloComputation*> called_computations; - // Pick different colors or shapes for instructions which are particularly // expensive (eg, dot) and those which are unusual in some way or unique // (eg, parameter). @@ -401,7 +399,8 @@ string InstructionSequenceGraph( } else { // Add a dotted edge between the instruction and any computations that the // instruction calls. 
- for (auto* computation : instruction->MakeCalledComputationsSet()) { + for (const HloComputation* computation : + instruction->called_computations()) { string cluster_name = StrCat("cluster_", ComputationId(computation)); string call_edge = Printf( "%s -> %s [ style=dashed; ltail=%s ];\n", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 968b953193..883f9751d1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -42,6 +41,10 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::Printf; + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -195,7 +198,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = map_computation; + instruction->called_computations_.push_back(map_computation); return instruction; } @@ -276,8 +279,9 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* init) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); - instruction->condition_ = condition; - instruction->body_ = body; + // Body comes before condition computation in the vector. 
+ instruction->called_computations_.push_back(body); + instruction->called_computations_.push_back(condition); return instruction; } @@ -345,7 +349,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, instruction->AppendOperand(init_value); instruction->dimensions_.assign(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); return instruction; } @@ -356,7 +360,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(init_value); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); instruction->window_ = MakeUnique<Window>(window); return instruction; } @@ -371,8 +375,9 @@ HloInstruction::CreateSelectAndScatter( instruction->AppendOperand(operand); instruction->AppendOperand(source); instruction->AppendOperand(init_value); - instruction->select_ = select; - instruction->scatter_ = scatter; + // Select comes before scatter in the vector. + instruction->called_computations_.push_back(select); + instruction->called_computations_.push_back(scatter); instruction->window_ = MakeUnique<Window>(window); return instruction; } @@ -480,7 +485,7 @@ HloInstruction* HloInstruction::FuseInstruction( CHECK_EQ(opcode_, HloOpcode::kFusion); // This fusion instruction must be a user of instruction_to_fuse. 
- CHECK_NE(0, instruction_to_fuse->users().count(this)); + CHECK(IsUserOf(instruction_to_fuse)); HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse); CheckFusionInstruction(); return fused_instruction; @@ -573,6 +578,14 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } + for (HloComputation* computation : + instruction_to_fuse->called_computations()) { + if (std::find(called_computations_.begin(), called_computations_.end(), + computation) == called_computations_.end()) { + called_computations_.push_back(computation); + } + } + return clone; } @@ -581,45 +594,6 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -namespace { - -// Adds any HloComputations this instruction calls directly to the given set. -void CalledComputationsInternal( - const HloInstruction& instruction, - std::set<HloComputation*>* called_computations) { - switch (instruction.opcode()) { - case HloOpcode::kCall: - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - called_computations->insert(instruction.to_apply()); - break; - case HloOpcode::kSelectAndScatter: - called_computations->insert(instruction.select()); - called_computations->insert(instruction.scatter()); - break; - case HloOpcode::kWhile: - called_computations->insert(instruction.while_condition()); - called_computations->insert(instruction.while_body()); - break; - case HloOpcode::kFusion: - for (const auto& fused_instruction : instruction.fused_instructions()) { - CalledComputationsInternal(*fused_instruction, called_computations); - } - break; - default: - break; - } -} - -} // namespace - -std::set<HloComputation*> HloInstruction::MakeCalledComputationsSet() const { - std::set<HloComputation*> called_computations; - CalledComputationsInternal(*this, &called_computations); - return called_computations; -} - void HloInstruction::CheckFusionInstruction() const 
{ CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -698,7 +672,7 @@ void HloInstruction::CheckFusionInstruction() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = computation; + instruction->called_computations_.push_back(computation); return instruction; } @@ -777,7 +751,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(operands.size(), 1); return CreateBroadcast(shape, operands[0], dimensions_); case HloOpcode::kCall: - return CreateCall(shape, operands, to_apply_); + return CreateCall(shape, operands, to_apply()); case HloOpcode::kCustomCall: return CreateCustomCall(shape, operands, custom_call_target_); case HloOpcode::kConcatenate: @@ -796,22 +770,22 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(operands.size(), 1); return CreateGetTupleElement(shape, operands[0], tuple_index()); case HloOpcode::kMap: - return CreateMap(shape, operands, to_apply_); + return CreateMap(shape, operands, to_apply()); case HloOpcode::kPad: CHECK_EQ(operands.size(), 2); return CreatePad(shape, operands[0], operands[1], *padding_config_); case HloOpcode::kReduce: CHECK_EQ(operands.size(), 2); return CreateReduce(shape, operands[0], operands[1], dimensions_, - to_apply_); + to_apply()); case HloOpcode::kReduceWindow: CHECK_EQ(operands.size(), 2); return CreateReduceWindow(shape, operands[0], operands[1], *window_, - to_apply_); + to_apply()); case HloOpcode::kSelectAndScatter: CHECK_EQ(operands.size(), 3); - return CreateSelectAndScatter(shape, operands[0], select_, *window_, - operands[1], operands[2], scatter_); + return CreateSelectAndScatter(shape, operands[0], select(), *window_, + operands[1], operands[2], scatter()); case HloOpcode::kRecv: CHECK_EQ(operands.size(), 0); return CreateRecv(shape, channel_id_); @@ -843,7 +817,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( return CreateTuple(operands_); case HloOpcode::kWhile: 
CHECK_EQ(operands.size(), 1); - return CreateWhile(shape, condition_, body_, operands[0]); + return CreateWhile(shape, while_condition(), while_body(), operands[0]); case HloOpcode::kConstant: return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); case HloOpcode::kFusion: @@ -973,12 +947,43 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand"; } +Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { + TF_RET_CHECK(instruction->parent() == parent()); + if (std::find(control_successors_.begin(), control_successors_.end(), + instruction) == control_successors_.end()) { + control_successors_.push_back(instruction); + TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), + this) == instruction->control_predecessors_.end()); + instruction->control_predecessors_.push_back(this); + } + return Status::OK(); +} + +Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { + auto succ_it = std::find(control_successors_.begin(), + control_successors_.end(), instruction); + TF_RET_CHECK(succ_it != control_successors_.end()); + control_successors_.erase(succ_it); + auto pred_it = std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), this); + TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); + instruction->control_predecessors_.erase(pred_it); + + return Status::OK(); +} + void HloInstruction::AppendOperand(HloInstruction* operand) { operands_.push_back(operand); operand->AddUser(this); } -void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); } +void HloInstruction::AddUser(HloInstruction* user) { + if (!ContainsKey(user_set_, user)) { + user_set_.insert(user); + users_.push_back(user); + } +} bool HloInstruction::IsConstant() const { return opcode_ == HloOpcode::kConstant; @@ -993,14 +998,6 @@ bool 
HloInstruction::HasConstantOperand() const { return false; } -void HloInstruction::AddControlPredecessor(HloInstruction* instruction) { - control_predecessors_.insert(instruction); -} - -void HloInstruction::AddControlSuccessor(HloInstruction* instruction) { - control_successors_.insert(instruction); -} - bool HloInstruction::Identical( const HloInstruction& other, std::function<bool(const HloInstruction*, const HloInstruction*)> @@ -1161,9 +1158,14 @@ bool HloInstruction::IsRank2Transpose() const { } void HloInstruction::RemoveUser(HloInstruction* user) { - auto user_it = users_.find(user); - CHECK(user_it != users_.end()); - users_.erase(user_it); + auto set_it = user_set_.find(user); + CHECK(set_it != user_set_.end()); + user_set_.erase(set_it); + // This is linear in the number of the users, but a vector provides a stable + // iteration order and much faster traversal. + auto vec_it = std::find(users_.begin(), users_.end(), user); + CHECK(vec_it != users_.end()); + users_.erase(vec_it); } Status HloInstruction::ReplaceUseWith(HloInstruction* user, @@ -1172,15 +1174,12 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); - auto user_it = std::find(users_.begin(), users_.end(), user); - TF_RET_CHECK(user_it != users_.end()) - << "Instruction " << user->name() << " not a use of instruction " - << name(); - users_.erase(user_it); VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); + RemoveUser(user); + TF_RET_CHECK( std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, @@ -1212,19 +1211,26 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { - // We can't use range-based loop because the iterator is 
invalidated by call - // to ReplaceUseWith. - for (auto user = users_.begin(); user != users_.end();) { - auto this_user = user; - user++; - // It's possible that new_producer is a user of this instruction as might - // be the case when replacing an instruction with a kCopy of itself. In - // this case, don't do the replacement to avoid creating a cycle in the - // graph. - if (*this_user != new_producer) { - TF_RETURN_IF_ERROR(ReplaceUseWith(*this_user, new_producer)); + bool new_producer_is_user = false; + for (HloInstruction* user : users()) { + if (user == new_producer) { + // It's possible that new_producer is a user of this instruction as might + // be the case when replacing an instruction with a kCopy of itself. In + // this case, don't do the replacement to avoid creating a cycle in the + // graph. new_producer remains the only user of this instruction. + new_producer_is_user = true; + } else { + std::replace(user->operands_.begin(), user->operands_.end(), this, + new_producer); + new_producer->AddUser(user); } } + users_.clear(); + user_set_.clear(); + if (new_producer_is_user) { + AddUser(new_producer); + } + return Status::OK(); } @@ -1235,7 +1241,7 @@ void HloInstruction::DetachFromOperands() { std::set<HloInstruction*> detached_operands; for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { HloInstruction* operand = operands_[operand_num]; - if (detached_operands.count(operand) == 0) { + if (!ContainsKey(detached_operands, operand)) { operand->RemoveUser(this); detached_operands.insert(operand); } @@ -1249,22 +1255,29 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - return to_apply_; + CHECK_EQ(called_computations_.size(), 1); + return called_computations_[0]; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } void 
HloInstruction::set_to_apply(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); switch (opcode_) { case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - to_apply_ = computation; + CHECK_EQ(called_computations_.size(), 1); + called_computations_[0] = computation; break; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } @@ -1280,49 +1293,60 @@ const string& HloInstruction::outfeed_config() const { HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return condition_; + return called_computations_[kConditionComputationIndex]; } HloComputation* HloInstruction::while_body() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return body_; + return called_computations_[kBodyComputationIndex]; } void HloInstruction::set_while_condition(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - condition_ = computation; + called_computations_[kConditionComputationIndex] = computation; } void HloInstruction::set_while_body(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. 
+ CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - body_ = computation; + called_computations_[kBodyComputationIndex] = computation; } HloComputation* HloInstruction::select() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return select_; + return called_computations_[kSelectComputationIndex]; } HloComputation* HloInstruction::scatter() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return scatter_; + return called_computations_[kScatterComputationIndex]; } void HloInstruction::set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - select_ = computation; + called_computations_[kSelectComputationIndex] = computation; } void HloInstruction::set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - scatter_ = computation; + called_computations_[kScatterComputationIndex] = computation; } string HloInstruction::SignatureString() const { - string operands = tensorflow::str_util::Join( - operands_, ", ", [](string* out, HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + string operands = + Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return tensorflow::strings::StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1343,7 +1367,7 @@ string HloInstruction::ToString(bool compact_operands) const { // empty entries. for (const auto& s : v) { if (s.empty()) continue; - tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s); + StrAppend(&operands, (first ? 
"" : " "), s); first = false; } } else { @@ -1356,31 +1380,26 @@ string HloInstruction::ToString(bool compact_operands) const { if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = tensorflow::str_util::Join( - slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact_operands) { - tensorflow::strings::StrAppend(out, " ", operand->name()); - } - }); + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + *out += ShapeUtil::HumanStringWithLayout(operand->shape()); + if (!compact_operands) { + StrAppend(out, " ", operand->name()); + } + }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { - tensorflow::strings::StrAppend(&operands, ", ...(+", remaining, ")"); + StrAppend(&operands, ", ...(+", remaining, ")"); } } string extra; if (CanHaveDimensionsField()) { - tensorflow::strings::StrAppend( - &extra, ", dimensions={", tensorflow::str_util::Join(dimensions(), ","), - "}"); + StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}"); } if (window_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", ", - window_util::ToString(*window_)); + StrAppend(&extra, ", ", window_util::ToString(*window_)); } if (padding_config_ != nullptr) { - tensorflow::strings::StrAppend( - &extra, ", padding=", padding_config_->ShortDebugString()); + StrAppend(&extra, ", padding=", padding_config_->ShortDebugString()); } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector<string> bounds; @@ -1388,8 +1407,7 @@ string HloInstruction::ToString(bool compact_operands) const { bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); } - tensorflow::strings::StrAppend( - &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}"); + StrAppend(&extra, ", slice={", Join(bounds, ", "), 
"}"); } if (convolution_dimension_numbers_ != nullptr) { const auto& dnums = *convolution_dimension_numbers_; @@ -1432,37 +1450,42 @@ string HloInstruction::ToString(bool compact_operands) const { extra += "->"; append_dims(lhs_dims, shape()); } - if (to_apply_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name()); - } + if (opcode() == HloOpcode::kWhile) { - tensorflow::strings::StrAppend(&extra, - ", condition=", while_condition()->name()); - tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name()); + StrAppend(&extra, ", condition=", while_condition()->name()); + StrAppend(&extra, ", body=", while_body()->name()); + } else if (opcode() == HloOpcode::kSelectAndScatter) { + StrAppend(&extra, ", select=", select()->name()); + StrAppend(&extra, ", scatter=", scatter()->name()); + } else if (!called_computations().empty()) { + StrAppend(&extra, ", calls=", + Join(called_computations(), ", ", + [](string* out, const HloComputation* computation) { + StrAppend(out, computation->name()); + })); } + if (opcode() == HloOpcode::kGetTupleElement) { - tensorflow::strings::StrAppend(&extra, ", index=", tuple_index()); + StrAppend(&extra, ", index=", tuple_index()); } if (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty()) { - tensorflow::strings::StrAppend( - &extra, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } - return tensorflow::strings::Printf( - "%s = %s %s(%s)%s", name().c_str(), - ShapeUtil::HumanStringWithLayout(shape()).c_str(), - HloOpcodeString(opcode()).c_str(), operands.c_str(), extra.c_str()); + return Printf("%s = %s %s(%s)%s", name().c_str(), + ShapeUtil::HumanStringWithLayout(shape()).c_str(), + HloOpcodeString(opcode()).c_str(), operands.c_str(), + extra.c_str()); } string HloInstruction::ToShortString() const { - return tensorflow::strings::Printf( - "%s = %s(%s)", name().c_str(), 
HloOpcodeString(opcode()).c_str(), - tensorflow::str_util::Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, operand->name()); - }) - .c_str()); + return Printf("%s = %s(%s)", name().c_str(), + HloOpcodeString(opcode()).c_str(), + Join(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, operand->name()); + }) + .c_str()); } string HloInstruction::ToCategory() const { @@ -1659,16 +1682,16 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kTuple: return visitor->HandleTuple(this, operands_); case HloOpcode::kMap: - return visitor->HandleMap(this, operands_, to_apply_, {}); + return visitor->HandleMap(this, operands_, to_apply(), {}); case HloOpcode::kClamp: return visitor->HandleClamp(this, operands_[0], operands_[1], operands_[2]); case HloOpcode::kReduce: return visitor->HandleReduce(this, operands_[0], operands_[1], - dimensions_, to_apply_); + dimensions_, to_apply()); case HloOpcode::kReduceWindow: return visitor->HandleReduceWindow(this, operands_[0], window(), - to_apply_); + to_apply()); case HloOpcode::kSelectAndScatter: return visitor->HandleSelectAndScatter(this); case HloOpcode::kNegate: @@ -1715,11 +1738,12 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kRng: return visitor->HandleRng(this, distribution_); case HloOpcode::kWhile: - return visitor->HandleWhile(this, operands_[0], condition_, body_); + return visitor->HandleWhile(this, operands_[0], while_condition(), + while_body()); case HloOpcode::kFusion: return visitor->HandleFusion(this); case HloOpcode::kCall: - return visitor->HandleCall(this, operands_, to_apply_); + return visitor->HandleCall(this, operands_, to_apply()); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this, operands_, custom_call_target_); case HloOpcode::kSend: @@ -1828,7 +1852,7 @@ bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) { // ops). 
for (auto* instruction : order) { for (auto* operand : instruction->operands()) { - if (order_position.count(operand) == 0 || + if (!ContainsKey(order_position, operand) || order_position.at(operand) >= order_position.at(instruction)) { return false; } @@ -1859,7 +1883,7 @@ Status HloInstruction::AcceptOrdered( })); for (auto* const_instruction : order) { - if (predecessors.count(const_instruction) == 0) { + if (!ContainsKey(predecessors, const_instruction)) { // Instruction is not a predecessors of 'this'. continue; } @@ -2008,7 +2032,7 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { HloInstruction* operand = worklist.front(); worklist.pop_front(); for (HloInstruction* user : operand->users()) { - if (visited.count(user)) { + if (ContainsKey(visited, user)) { continue; } if (user->IsElementwise() || @@ -2063,7 +2087,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { hlo.parameter_number_ == i) { return UseKind::kUse; } - if (cache.count(&hlo) == 0) { + if (!ContainsKey(cache, &hlo)) { for (int64 j = 0; j < hlo.operands_.size(); ++j) { UseKind old = cache[&hlo]; UseKind updated = plus( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index af960bd364..926a984d22 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -24,11 +24,12 @@ limitations under the License. #include <functional> #include <list> #include <memory> -#include <set> #include <string> #include <tuple> +#include <unordered_set> #include <vector> +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -303,28 +304,38 @@ class HloInstruction { int64 user_count() const { return users_.size(); } // Returns the users of this instruction. 
- const std::set<HloInstruction*>& users() const { return users_; } + const std::vector<HloInstruction*>& users() const { return users_; } - // Returns the set of control predecessors of this instruction. Control - // predecessors are the instructions that must be scheduled before the current - // instruction. - const std::set<HloInstruction*>& control_predecessors() const { - return control_predecessors_; + // Returns true if this instruction is a user of 'instruction'. + bool IsUserOf(const HloInstruction* instruction) const { + return ContainsKey(instruction->user_set_, this); } - // Adds the given instruction to the set of control predecessors. - void AddControlPredecessor(HloInstruction* instruction); - - // Returns the set of control successors of this instruction. Control - // successors are the instructions that must be scheduled after the current - // instruction. - const std::set<HloInstruction*>& control_successors() const { + // Adds a control dependency from this instruction to the given + // instruction. This instruction becomes a control predecessor of + // 'instruction', and 'instruction' becomes a control successor of this + // instruction. Returns an error status if either of the given instructions + // does not belong to the same computation. + // + // This is used to enforce an additional ordering requirement that is not + // captured by normal data dependencies, such as ordering among Send or Recv + // operations to avoid deadlock. + Status AddControlDependencyTo(HloInstruction* instruction); + + // Removes a previously added control dependency from this instruction to + // 'instruction'. + Status RemoveControlDependencyTo(HloInstruction* instruction); + + // Returns the set of control predecessors (successors) of this + // instruction. Control predecessors (successors) must execute before (after) + // the current instruction.
+ const std::vector<HloInstruction*>& control_predecessors() const { + return control_predecessors_; + } + const std::vector<HloInstruction*>& control_successors() const { return control_successors_; } - // Adds the given instruction to the set of control successors. - void AddControlSuccessor(HloInstruction* instruction); - // Returns true if "other" performs the same computation as this instruction. // Layout of the instructions' output array is not considered. bool Identical( @@ -636,10 +647,11 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands); - // Computes and returns the computations this instruction calls (if any). This - // includes computations called by fused instructions inside of a fusion - // instruction. - std::set<HloComputation*> MakeCalledComputationsSet() const; + // Returns the computations this instruction calls (if any). This includes + // computations called by fused instructions inside of a fusion instruction. + const std::vector<HloComputation*>& called_computations() const { + return called_computations_; + } // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, @@ -806,21 +818,23 @@ class HloInstruction { int64 parameter_number_ = 0; string parameter_name_; - // Computation to apply, only present for kCall, kMap, kReduce and - // kReduceWindow. - HloComputation* to_apply_ = nullptr; - // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; - // Computation for condition and body of kWhile, only present for kWhile. - HloComputation* condition_ = nullptr; - HloComputation* body_ = nullptr; + // Computations called by this instruction. + std::vector<HloComputation*> called_computations_; + + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. 
+ kBodyComputationIndex = 0, + kConditionComputationIndex = 1, - // Computation for select and scatter, only present for - // kSelectAndScatter. - HloComputation* select_ = nullptr; - HloComputation* scatter_ = nullptr; + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + }; // Outfeed configuration information, only present for kOutfeed. string outfeed_config_; @@ -829,14 +843,17 @@ class HloInstruction { std::vector<HloInstruction*> operands_; // The users of this instruction. Users are HLOs where this instruction is an - // operand. - std::set<HloInstruction*> users_; + // operand. The vector users_ and the set user_set_ contain identical + // members. The set enables fast membership testing and the vector enables + // fast, stable iteration. + std::vector<HloInstruction*> users_; + std::unordered_set<const HloInstruction*> user_set_; // The set of control predecessors of this instruction. - std::set<HloInstruction*> control_predecessors_; + std::vector<HloInstruction*> control_predecessors_; // The set of control successors of this instruction. - std::set<HloInstruction*> control_successors_; + std::vector<HloInstruction*> control_successors_; // A trace instruction that consumes this instruction. // diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 48711b605f..8eabaa1c47 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -25,15 +25,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace { -#define EXPECT_ISET(A, E...) 
EXPECT_EQ(A, (std::set<HloInstruction*>{E})) -#define EXPECT_IVEC(A, E...) EXPECT_EQ(A, (std::vector<HloInstruction*>{E})) - -class HloInstructionTest : public ::testing::Test { +class HloInstructionTest : public HloTestBase { protected: HloInstructionTest() {} @@ -149,10 +147,10 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), bar.get()); - EXPECT_MATCH(add->operands(), testing::UnorderedMatcher<HloInstruction*>( - foo.get(), bar.get())); - EXPECT_ISET(foo->users(), add.get()); - EXPECT_ISET(bar->users(), add.get()); + + ExpectEqOrdered(add->operands(), {foo.get(), bar.get()}); + ExpectEqUnordered(foo->users(), {add.get()}); + ExpectEqUnordered(bar->users(), {add.get()}); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -385,12 +383,12 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { EXPECT_EQ(1, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - EXPECT_ISET(foo->users(), add_foobar.get()); - EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); + ExpectEqUnordered(foo->users(), {add_foobar.get()}); + ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()}); - EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); - EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); - EXPECT_IVEC(add_foofoo->operands(), bar.get(), bar.get()); + ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()}); + ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()}); + ExpectEqOrdered(add_foofoo->operands(), {bar.get(), bar.get()}); } TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { @@ -406,15 +404,16 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { foo.get(), bar.get()); EXPECT_EQ(2, foo->user_count()); - EXPECT_ISET(foo->users(), tuple.get(), add_foobar.get()); + ExpectEqUnordered(foo->users(), {tuple.get(), add_foobar.get()}); // Replace the use of 
foo in tuple with bar. ASSERT_IS_OK(foo->ReplaceUseWith(tuple.get(), bar.get())); - EXPECT_ISET(foo->users(), add_foobar.get()); + ExpectEqUnordered(foo->users(), {add_foobar.get()}); // Both uses of foo in tuple should have been replaced with bar. - EXPECT_IVEC(tuple->operands(), bar.get(), bar.get(), baz.get(), bar.get()); + ExpectEqOrdered(tuple->operands(), + {bar.get(), bar.get(), baz.get(), bar.get()}); } TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { @@ -427,7 +426,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get()); EXPECT_EQ(2, foo->user_count()); - EXPECT_ISET(foo->users(), exp.get(), log.get()); + ExpectEqUnordered(foo->users(), {exp.get(), log.get()}); EXPECT_EQ(0, bar->user_count()); // Replace the use of foo in exp with bar. @@ -435,8 +434,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { // The use of foo in log should not have been affected. EXPECT_EQ(1, foo->user_count()); - EXPECT_ISET(foo->users(), log.get()); - EXPECT_IVEC(log->operands(), foo.get()); + ExpectEqUnordered(foo->users(), {log.get()}); + ExpectEqOrdered(log->operands(), {foo.get()}); // Bar should now be used in exp. EXPECT_EQ(1, bar->user_count()); @@ -467,7 +466,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); + ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()}); } TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { @@ -491,7 +490,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(3, bar->user_count()); - EXPECT_ISET(bar->users(), add_foobar.get(), exp.get(), tuple.get()); + ExpectEqUnordered(bar->users(), {add_foobar.get(), exp.get(), tuple.get()}); } // Simple visitor that collects and post-processes each node in the graph. 
@@ -559,8 +558,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); - EXPECT_IVEC(fusion->operands(), constant.get()); - EXPECT_ISET(constant->users(), fusion.get(), exp.get()); + ExpectEqOrdered(fusion->operands(), {constant.get()}); + ExpectEqUnordered(constant->users(), {fusion.get(), exp.get()}); } TEST_F(HloInstructionTest, BinaryFusionOp) { @@ -575,9 +574,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, add.get()); - EXPECT_IVEC(fusion->operands(), constant1.get(), constant2.get()); - EXPECT_ISET(constant1->users(), fusion.get(), add.get()); - EXPECT_ISET(constant2->users(), fusion.get(), add.get()); + ExpectEqOrdered(fusion->operands(), {constant1.get(), constant2.get()}); + ExpectEqUnordered(constant1->users(), {fusion.get(), add.get()}); + ExpectEqUnordered(constant2->users(), {fusion.get(), add.get()}); } TEST_F(HloInstructionTest, ChainFusionOp) { @@ -594,8 +593,48 @@ TEST_F(HloInstructionTest, ChainFusionOp) { fusion->FuseInstruction(exp2.get()); fusion->FuseInstruction(exp1.get()); - EXPECT_IVEC(fusion->operands(), constant.get()); - EXPECT_ISET(constant->users(), fusion.get(), exp1.get()); + ExpectEqOrdered(fusion->operands(), {constant.get()}); + ExpectEqUnordered(constant->users(), {fusion.get(), exp1.get()}); +} + +TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { + // Create a fusion instruction containing a chain of map operations.
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + + auto make_map_computation = [&]() { + auto builder = HloComputation::Builder("FusionMap"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + return builder.Build(); + }; + + std::unique_ptr<HloComputation> computation_x = make_map_computation(); + std::unique_ptr<HloComputation> computation_y = make_map_computation(); + + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)); + auto map_1_x = + HloInstruction::CreateMap(scalar_shape, {constant.get()}, + computation_x.get(), /*static_operands=*/{}); + auto map_2_x = + HloInstruction::CreateMap(scalar_shape, {map_1_x.get()}, + computation_x.get(), /*static_operands=*/{}); + auto map_3_y = + HloInstruction::CreateMap(scalar_shape, {map_2_x.get()}, + computation_y.get(), /*static_operands=*/{}); + + auto fusion = HloInstruction::CreateFusion( + scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get()); + + ASSERT_EQ(fusion->called_computations().size(), 1); + EXPECT_EQ(fusion->called_computations()[0], computation_y.get()); + + fusion->FuseInstruction(map_2_x.get()); + ASSERT_EQ(fusion->called_computations().size(), 2); + EXPECT_EQ(fusion->called_computations()[1], computation_x.get()); + + fusion->FuseInstruction(map_1_x.get()); + ASSERT_EQ(fusion->called_computations().size(), 2); } TEST_F(HloInstructionTest, ComplexFusionOp) { @@ -636,8 +675,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Operands in the fusion instruction's operands() vector should be in the // order in which their users were added fused. 
- EXPECT_IVEC(fusion->operands(), c1.get(), c3.get(), c2.get()); - EXPECT_ISET(c1->users(), add.get(), tuple.get(), fusion.get()); + ExpectEqOrdered(fusion->operands(), {c1.get(), c3.get(), c2.get()}); + ExpectEqUnordered(c1->users(), {add.get(), tuple.get(), fusion.get()}); } // Convenience function for comparing two HloInstructions inside of diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 5d68b456cd..36064e93fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -232,7 +232,8 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const { std::set<HloComputation*> nonroot_computations; for (auto& computation : computations_) { for (auto& instruction : computation->instructions()) { - for (auto called_computation : instruction->MakeCalledComputationsSet()) { + for (HloComputation* called_computation : + instruction->called_computations()) { nonroot_computations.insert(called_computation); } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 8775d9f888..b3168ed40e 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -292,11 +292,24 @@ class ListScheduler { std::vector<const HloInstruction*> CreateSchedule() { std::vector<const HloInstruction*> schedule; - // Populate the ready list with instructions which have no operands. + // Populate the ready list with instructions which have no operands or + // control predecessors. 
+ std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count; std::list<const HloInstruction*> ready_list; for (auto& instruction : computation_.instructions()) { - if (instruction->operand_count() == 0 && - instruction->control_predecessors().empty()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + for (auto& instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. + if (unscheduled_pred_count.count(instruction.get()) == 0) { ready_list.push_back(instruction.get()); } } @@ -328,28 +341,21 @@ class ListScheduler { } // Add new instructions to ready list. - // TODO(b/34466113): Replace this with successors()/predecessors() when - // predecessor/successor methods are added to HloInstruction. This also - // will resolve the nondeterminism of using a set here assuming - // predecessors/successors is a vector. 
- std::set<HloInstruction*> successors = best->users(); - successors.insert(best->control_successors().begin(), - best->control_successors().end()); - for (auto* successor : successors) { - std::set<HloInstruction*> predecessors(successor->operands().begin(), - successor->operands().end()); - predecessors.insert(successor->control_predecessors().begin(), - successor->control_predecessors().end()); - bool is_ready = true; - for (auto* predecessor : predecessors) { - if (scheduled_instructions_.count(predecessor) == 0) { - is_ready = false; - break; - } - } - if (is_ready) { - ready_list.push_back(successor); + auto update_pred_count = [&unscheduled_pred_count, + &ready_list](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + ready_list.push_back(inst); } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); } } CHECK_EQ(schedule.size(), computation_.instructions().size()); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index c60e13fc9c..7160129c12 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -106,8 +106,7 @@ bool IsExpensive(const HloInstruction& instruction) { } bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) { - return !(producer->users().size() == 1 && - producer->users().count(consumer) == 1); + return !(producer->users().size() == 1 && consumer->IsUserOf(producer)); } StatusOr<bool> InstructionFusion::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 
91fc9b87cd..6119473d81 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -84,6 +84,28 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments); + // Helpers for comparing ordered and unordered equality of HloInstruction + // containers. + void ExpectEqOrdered( + tensorflow::gtl::ArraySlice<const HloInstruction*> actual, + tensorflow::gtl::ArraySlice<const HloInstruction*> expected) { + std::vector<const HloInstruction*> expected_vec(expected.begin(), + expected.end()); + std::vector<const HloInstruction*> actual_vec(actual.begin(), actual.end()); + EXPECT_TRUE(testing::VectorMatcher<const HloInstruction*>(expected_vec)( + actual_vec)); + } + + void ExpectEqUnordered( + tensorflow::gtl::ArraySlice<const HloInstruction*> actual, + tensorflow::gtl::ArraySlice<const HloInstruction*> expected) { + std::vector<const HloInstruction*> expected_vec(expected.begin(), + expected.end()); + std::vector<const HloInstruction*> actual_vec(actual.begin(), actual.end()); + EXPECT_TRUE(testing::UnorderedElementsAre<const HloInstruction*>( + expected_vec)(actual_vec)); + } + string TestName() const; std::unique_ptr<Backend> backend_; |