From aeab291563b0b4cc75c0f5fc73610a6595780570 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 21 Aug 2018 07:46:55 -0700 Subject: Handle communicating instructions in HloComputation::ComputeReachability Send&recv instructions and cross-replica-sum instructions are imposing extra dependencies via the channel id or all reduce id. This CL teaches the reachability calculation logic in hlo computation to correctly account for these "invisible" dependencies. The main purpose is to stop multi output fusion from generating dependency cycles via communicating instructions. PiperOrigin-RevId: 209593997 --- tensorflow/compiler/xla/service/hlo_computation.cc | 80 +++++++++++++++++++++- tensorflow/compiler/xla/service/hlo_computation.h | 7 ++ .../compiler/xla/service/hlo_computation_test.cc | 23 ++++++- 3 files changed, 108 insertions(+), 2 deletions(-) (limited to 'tensorflow/compiler/xla') diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index bae78c94bd..70b18ff356 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -321,6 +321,7 @@ void ComputeComputationPostOrder( enum State { kVisiting, kVisited }; void ComputeInstructionPostOrder( + std::map> channel_dependency_map, std::vector* post_order, HloInstruction* root, tensorflow::gtl::FlatMap* visited) { std::vector dfs_stack; @@ -355,12 +356,67 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. 
+ switch (current->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = + channel_dependency_map[current->channel_id()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + } + break; + } + default: + break; + } } } } // namespace +std::map> +HloComputation::ComputeChannelDependencies() const { + std::map> channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} + std::vector HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; @@ -372,7 +428,8 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). 
trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -676,12 +733,33 @@ std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); auto result = absl::make_unique(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = channel_dependency_map[hlo->channel_id()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 49ed65910f..faa33f0f90 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -399,6 +399,13 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. 
+ std::map> ComputeChannelDependencies() + const; + string name_; int64 unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index e4c5470331..f7ed1b0316 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -} // namespace +TEST_F(HloComputationTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = computation->ComputeReachability(); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +} // namespace } // namespace xla -- cgit v1.2.3