aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-15 16:04:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-15 16:07:40 -0700
commit6c3c766dcabff3b5fa41dbfd491c9e8062a77b07 (patch)
tree744fa2bc46e9b446e5645bd4b6962ce641898fd6 /tensorflow/compiler/xla/service/hlo_module_group_util.cc
parente5945c00148186808e337b4946cf0fa6460f6803 (diff)
[XLA] Enable the semantic for cross-modeul AllReduce.
PiperOrigin-RevId: 204670087
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_util.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc24
1 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index df1d562048..9fd0ade153 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -56,12 +56,17 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
};
// If the given instruction is a companion instruction, we need to find the
- // predecessors of all of its companion instructions.
+ // predecessors of all of its companion instructions. If the instruction is an
+ // all-reduce, we need to find the predecessors of all the peer all-reduce
+ // instructions.
std::vector<HloInstruction*> instruction_group;
if (metadata_.IsCompanionInstruction(instruction)) {
for (HloInstruction* companion : metadata_.Companions(instruction)) {
instruction_group.push_back(companion);
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ instruction_group =
+ metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else {
instruction_group.push_back(instruction);
}
@@ -112,12 +117,17 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
};
// If the given instruction is a companion instruction, we need to find the
- // successors of all of its companion instructions.
+ // successors of all of its companion instructions. If the instruction is an
+ // all-reduce, we need to find the successors of all its peer all-reduce
+ // instructions.
std::vector<HloInstruction*> instruction_group;
if (metadata_.IsCompanionInstruction(instruction)) {
for (HloInstruction* companion : metadata_.Companions(instruction)) {
instruction_group.push_back(companion);
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ instruction_group =
+ metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else {
instruction_group.push_back(instruction);
}
@@ -170,15 +180,17 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
HloInstruction* hlo = stack.top();
// Find the instruction group of the currently visited instruction. The
- // instruction group represents all companion instructions of the
- // current instruction, and are considered to be a single entity for the
- // purpose of the traversal (i.e., they must always be in the same visit
- // state).
+ // instruction group represents all companion instructions of the current
+ // instruction, or all the all-reduce instructions that belong to the same
+ // group, or are considered to be a single entity for the purpose of the
+ // traversal (i.e., they must always be in the same visit state).
std::vector<HloInstruction*> instruction_group;
if (metadata_.IsCompanionInstruction(hlo)) {
for (HloInstruction* companion : metadata_.Companions(hlo)) {
instruction_group.push_back(companion);
}
+ } else if (hlo->IsCrossModuleAllReduce()) {
+ instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
} else {
instruction_group.push_back(hlo);
}