diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-21 07:46:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 07:50:58 -0700 |
commit | aeab291563b0b4cc75c0f5fc73610a6595780570 (patch) | |
tree | 8b5e9cc861551cec0f4ba0cfcfa1b0e788109ae5 /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | ad4018ebb6b7f8ee4a38a7a0059bdbad28557732 (diff) |
Handle communicating instructions in HloComputation::ComputeReachability
Send&recv instructions and cross-replica-sum instructions are imposing
extra dependencies via the channel id or all reduce id. This CL teaches
the reachability calculation logic in hlo computation to correctly
account for these "invisible" dependencies.
The main purpose is to stop multi output fusion from generating
dependency cyclies via communicating instructions.
PiperOrigin-RevId: 209593997
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 80 |
1 files changed, 79 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index bae78c94bd..70b18ff356 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -321,6 +321,7 @@ void ComputeComputationPostOrder( enum State { kVisiting, kVisited }; void ComputeInstructionPostOrder( + std::map<int64, std::vector<HloInstruction*>> channel_dependency_map, std::vector<HloInstruction*>* post_order, HloInstruction* root, tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) { std::vector<HloInstruction*> dfs_stack; @@ -355,12 +356,67 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. + switch (current->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = + channel_dependency_map[current->channel_id()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + for (HloInstruction* op : dependencies) { + dfs_stack.emplace_back(op); + } + } + break; + } + default: + break; + } } } } // namespace +std::map<int64, std::vector<HloInstruction*>> +HloComputation::ComputeChannelDependencies() const { + std::map<int64, std::vector<HloInstruction*>> channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} + std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector<HloInstruction*> post_order; post_order.reserve(instruction_count()); std::vector<HloInstruction*> trace_instructions; @@ -372,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -676,12 +733,33 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); auto result = absl::make_unique<HloReachabilityMap>(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector<HloInstruction*> inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + const auto& dependencies = channel_dependency_map[hlo->channel_id()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + const auto& dependencies = + channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(dependencies, std::back_inserter(inputs)); + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; |