aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_util.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc72
1 files changed, 51 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 9fd0ade153..0dc5676148 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -37,24 +38,38 @@ namespace xla {
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> predecessors;
-
- // Adds to the unique predecessors list and also add companion instructions
- // if the given predecessor has those.
+ std::vector<HloInstruction*>
+ predecessors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+ // Adds to the unique predecessors list; if the predecessor is a companion
+ // instruction, also add companion instructions; if the predecessor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
auto add_unique_predecessor = [&](HloInstruction* predecessor) {
- if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
- predecessors.end()) {
+ if (unique.find(predecessor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(predecessor)) {
- predecessors.push_back(predecessor);
+ if (metadata_.IsCompanionInstruction(predecessor)) {
+ for (HloInstruction* instr : metadata_.Companions(predecessor)) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(predecessor)) {
- predecessors.push_back(companion);
+ if (predecessor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(predecessor);
+ predecessors.push_back(predecessor);
};
-
// If the given instruction is a companion instruction, we need to find the
// predecessors of all of its companion instructions. If the instruction is an
// all-reduce, we need to find the predecessors of all the peer all-reduce
@@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> successors;
-
- // Adds to the unique successors list and also add companion instructions
- // if the given successor has those.
+ std::vector<HloInstruction*>
+ successors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+ // Adds to the unique successors list; if the successor is a companion
+ // instruction, also add companion instructions; if the successor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
auto add_unique_successor = [&](HloInstruction* successor) {
- if (std::find(successors.begin(), successors.end(), successor) !=
- successors.end()) {
+ if (unique.find(successor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(successor)) {
- successors.push_back(successor);
+ if (metadata_.IsCompanionInstruction(successor)) {
+ for (HloInstruction* instr : metadata_.Companions(successor)) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(successor)) {
- successors.push_back(companion);
+ if (successor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(successor);
+ successors.push_back(successor);
};
// If the given instruction is a companion instruction, we need to find the