about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_metadata.cc 57
1 file changed, 51 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 6bcd7b042d..10bf9ffd6c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -75,10 +77,23 @@ Status HloModuleGroupMetadata::Build() {
if (tracked == nullptr) {
return Status::OK();
}
- // Add the parent computation of this channel instruction and its peer
- // computation (both must be while computations) as companions.
+
+ std::vector<HloComputation*> peers;
if (IsChannelInstruction(hlo)) {
- HloComputation* peer_computation = PeerComputation(hlo);
+ peers.push_back(PeerComputation(hlo));
+ } else if (hlo->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
+ if (instr == hlo) {
+ continue;
+ }
+ peers.push_back(instr->parent());
+ }
+ }
+
+ // Add the parent computation of this channel (or all-reduce) instruction
+ // and its peer computation(s) (both must be while computations) as
+ // companions.
+ for (HloComputation* peer_computation : peers) {
const TrackedInstruction* peer_tracked =
GetTrackedInstruction(peer_computation);
TF_RET_CHECK(peer_tracked != nullptr)
@@ -162,8 +177,12 @@ bool HloModuleGroupMetadata::IsChannelInstruction(
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kSendDone:
- case HloOpcode::kRecvDone:
- return true;
+ case HloOpcode::kRecvDone: {
+ const HloSendRecvInstruction* send_recv_instr =
+ DynCast<HloSendRecvInstruction>(instruction);
+ CHECK(send_recv_instr != nullptr);
+ return !send_recv_instr->is_host_transfer();
+ }
default:
return false;
}
@@ -175,7 +194,8 @@ bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
bool HloModuleGroupMetadata::InstructionCommunicates(
HloInstruction* hlo) const {
- return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo);
+ return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) ||
+ hlo->IsCrossModuleAllReduce();
}
const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
@@ -200,6 +220,13 @@ HloComputation* HloModuleGroupMetadata::PeerComputation(
}
}
+const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
+ int64 all_reduce_id) const {
+ auto it = all_reduce_map_.find(all_reduce_id);
+ CHECK(it != all_reduce_map_.end());
+ return it->second;
+}
+
std::vector<HloModuleGroupMetadata::TrackedInstruction>
HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
std::vector<TrackedInstruction> path;
@@ -278,10 +305,27 @@ Status HloModuleGroupMetadata::RecordInstructions() {
tracked_instructions_[hlo->to_apply()] =
TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
+
+ // Group cross module all-reduce instructions by the all_reduce id.
+ if (hlo->IsCrossModuleAllReduce()) {
+ TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
+ channel_id_map_.end())
+ << "all_reduce_id " << *hlo->all_reduce_id()
+ << " is already used by a send/recv instruction";
+ all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
+ max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
+ return Status::OK();
+ }
+
if (!IsChannelInstruction(hlo)) {
return Status::OK();
}
+ TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
+ all_reduce_map_.end())
+ << "channel id " << hlo->channel_id()
+ << " is already used by an all-reduce instruction";
+
// Add a new channel if needed.
if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
channels_.emplace_back();
@@ -324,6 +368,7 @@ Status HloModuleGroupMetadata::RecordInstructions() {
}
}
VLOG(2) << "Created " << channels_.size() << " channels";
+ VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups";
return Status::OK();
}