diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 80 |
1 files changed, 79 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index bae78c94bd..70b18ff356 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -321,6 +321,7 @@ void ComputeComputationPostOrder(
 enum State { kVisiting, kVisited };
 
 void ComputeInstructionPostOrder(
+    std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
     std::vector<HloInstruction*>* post_order, HloInstruction* root,
     tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
   std::vector<HloInstruction*> dfs_stack;
@@ -355,12 +356,67 @@ void ComputeInstructionPostOrder(
     for (HloInstruction* op : current->control_predecessors()) {
       dfs_stack.emplace_back(op);
     }
+
+    // Add inputs for send->recv_done dependencies and cross-replica-sum
+    // dependencies.
+    switch (current->opcode()) {
+      case HloOpcode::kRecvDone: {
+        const auto& dependencies =
+            channel_dependency_map[current->channel_id()];
+        for (HloInstruction* op : dependencies) {
+          dfs_stack.emplace_back(op);
+        }
+        break;
+      }
+      case HloOpcode::kCrossReplicaSum: {
+        auto all_reduce_id = current->all_reduce_id();
+        if (all_reduce_id) {
+          const auto& dependencies =
+              channel_dependency_map[all_reduce_id.value()];
+          for (HloInstruction* op : dependencies) {
+            dfs_stack.emplace_back(op);
+          }
+        }
+        break;
+      }
+      default:
+        break;
+    }
   }
 }
 
 }  // namespace
 
+std::map<int64, std::vector<HloInstruction*>>
+HloComputation::ComputeChannelDependencies() const {
+  std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
+  for (const auto& instruction : instructions_) {
+    switch (instruction->opcode()) {
+      case HloOpcode::kSend: {
+        channel_dependency_map[instruction->channel_id()].push_back(
+            instruction.get());
+        break;
+      }
+      case HloOpcode::kCrossReplicaSum: {
+        auto all_reduce_id = instruction->all_reduce_id();
+        if (all_reduce_id) {
+          auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+          absl::c_copy(instruction->operands(),
+                       std::back_inserter(dependencies));
+          absl::c_copy(instruction->control_predecessors(),
+                       std::back_inserter(dependencies));
+        }
+        break;
+      }
+      default:
+        break;
+    }
+  }
+  return channel_dependency_map;
+}
+
 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+  auto channel_dependency_map = ComputeChannelDependencies();
   std::vector<HloInstruction*> post_order;
   post_order.reserve(instruction_count());
   std::vector<HloInstruction*> trace_instructions;
@@ -372,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
       // users).
       trace_instructions.push_back(instruction.get());
     } else if (instruction->users().empty()) {
-      ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+      ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+                                  instruction.get(), &visited);
     }
   }
   post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -676,12 +733,33 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
     const {
   const auto& all = MakeInstructionPostOrder();
   auto result = absl::make_unique<HloReachabilityMap>(all);
+  auto channel_dependency_map = ComputeChannelDependencies();
 
   std::vector<HloInstruction*> inputs;
   for (const HloInstruction* hlo : all) {
     inputs.assign(hlo->operands().begin(), hlo->operands().end());
     inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
                   hlo->control_predecessors().end());
+
+    switch (hlo->opcode()) {
+      case HloOpcode::kRecvDone: {
+        const auto& dependencies = channel_dependency_map[hlo->channel_id()];
+        absl::c_copy(dependencies, std::back_inserter(inputs));
+        break;
+      }
+      case HloOpcode::kCrossReplicaSum: {
+        auto all_reduce_id = hlo->all_reduce_id();
+        if (all_reduce_id) {
+          const auto& dependencies =
+              channel_dependency_map[all_reduce_id.value()];
+          absl::c_copy(dependencies, std::back_inserter(inputs));
+        }
+        break;
+      }
+      default:
+        break;
+    }
+
     result->FastSetReachabilityToUnion(inputs, hlo);
   }
   return result;