diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_group_metadata.cc | 57 |
1 files changed, 51 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 6bcd7b042d..10bf9ffd6c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -75,10 +77,23 @@ Status HloModuleGroupMetadata::Build() { if (tracked == nullptr) { return Status::OK(); } - // Add the parent computation of this channel instruction and its peer - // computation (both must be while computations) as companions. + + std::vector<HloComputation*> peers; if (IsChannelInstruction(hlo)) { - HloComputation* peer_computation = PeerComputation(hlo); + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); + } + } + + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. + for (HloComputation* peer_computation : peers) { const TrackedInstruction* peer_tracked = GetTrackedInstruction(peer_computation); TF_RET_CHECK(peer_tracked != nullptr) @@ -162,8 +177,12 @@ bool HloModuleGroupMetadata::IsChannelInstruction( case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kSendDone: - case HloOpcode::kRecvDone: - return true; + case HloOpcode::kRecvDone: { + const HloSendRecvInstruction* send_recv_instr = + DynCast<HloSendRecvInstruction>(instruction); + CHECK(send_recv_instr != nullptr); + return !send_recv_instr->is_host_transfer(); + } default: return false; } @@ -175,7 +194,8 @@ bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { bool HloModuleGroupMetadata::InstructionCommunicates( HloInstruction* hlo) const { - return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) || + hlo->IsCrossModuleAllReduce(); } const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( @@ -200,6 +220,13 @@ HloComputation* HloModuleGroupMetadata::PeerComputation( } } +const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup( + int64 all_reduce_id) const { + auto it = all_reduce_map_.find(all_reduce_id); + CHECK(it != all_reduce_map_.end()); + return it->second; +} + std::vector<HloModuleGroupMetadata::TrackedInstruction> HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { std::vector<TrackedInstruction> path; @@ -278,10 +305,27 @@ Status HloModuleGroupMetadata::RecordInstructions() { tracked_instructions_[hlo->to_apply()] = TrackedInstruction(hlo, ComputationKind::kCallFunction); } + + // Group cross module all-reduce instructions by the all_reduce id. + if (hlo->IsCrossModuleAllReduce()) { + TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) == + channel_id_map_.end()) + << "all_reduce_id " << *hlo->all_reduce_id() + << " is already used by a send/recv instruction"; + all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo); + max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id()); + return Status::OK(); + } + if (!IsChannelInstruction(hlo)) { return Status::OK(); } + TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) == + all_reduce_map_.end()) + << "channel id " << hlo->channel_id() + << " is already used by an all-reduce instruction"; + // Add a new channel if needed. if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { channels_.emplace_back(); @@ -324,6 +368,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { } } VLOG(2) << "Created " << channels_.size() << " channels"; + VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups"; return Status::OK(); } |