diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_group_util.cc | 72 |
1 files changed, 51 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade153..0dc5676148 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +38,38 @@ namespace xla { std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector<HloInstruction*> predecessors; - - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + std::vector<HloInstruction*> + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet<HloInstruction*> unique; + + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector<HloInstruction*> successors; - - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + std::vector<HloInstruction*> + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet<HloInstruction*> unique; + + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the |