diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-15 16:04:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-15 16:07:40 -0700 |
commit | 6c3c766dcabff3b5fa41dbfd491c9e8062a77b07 (patch) | |
tree | 744fa2bc46e9b446e5645bd4b6962ce641898fd6 /tensorflow/compiler/xla/service/hlo_module_group_util.cc | |
parent | e5945c00148186808e337b4946cf0fa6460f6803 (diff) |
[XLA] Enable the semantic for cross-modeul AllReduce.
PiperOrigin-RevId: 204670087
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_group_util.cc | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index df1d562048..9fd0ade153 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -56,12 +56,17 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( }; // If the given instruction is a companion instruction, we need to find the - // predecessors of all of its companion instructions. + // predecessors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the predecessors of all the peer all-reduce + // instructions. std::vector<HloInstruction*> instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -112,12 +117,17 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( }; // If the given instruction is a companion instruction, we need to find the - // successors of all of its companion instructions. + // successors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the successors of all its peer all-reduce + // instructions. std::vector<HloInstruction*> instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -170,15 +180,17 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( HloInstruction* hlo = stack.top(); // Find the instruction group of the currently visited instruction. The - // instruction group represents all companion instructions of the - // current instruction, and are considered to be a single entity for the - // purpose of the traversal (i.e., they must always be in the same visit - // state). + // instruction group represents all companion instructions of the current + // instruction, or all the all-reduce instructions that belong to the same + // group, or are considered to be a single entity for the purpose of the + // traversal (i.e., they must always be in the same visit state). std::vector<HloInstruction*> instruction_group; if (metadata_.IsCompanionInstruction(hlo)) { for (HloInstruction* companion : metadata_.Companions(hlo)) { instruction_group.push_back(companion); } + } else if (hlo->IsCrossModuleAllReduce()) { + instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id()); } else { instruction_group.push_back(hlo); } |