aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_util.cc
blob: fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module_group_util.h"

#include <algorithm>
#include <list>
#include <queue>
#include <stack>
#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
    HloInstruction* instruction) {
  std::vector<HloInstruction*>
      predecessors;  // Use a vector to avoid non-determinism.
  absl::flat_hash_set<HloInstruction*> unique;

  // Adds to the unique predecessors list; if the predecessors is a companion
  // instruction, also add companion instructions; if the predecessors is a
  // cross-module all-reduce, also add the all-reduce instructions in the same
  // group.
  auto add_unique_predecessor = [&](HloInstruction* predecessor) {
    if (unique.find(predecessor) != unique.end()) {
      return;
    }
    if (metadata_.IsCompanionInstruction(predecessor)) {
      for (HloInstruction* instr : metadata_.Companions(predecessor)) {
        if (unique.insert(instr).second) {
          predecessors.push_back(instr);
        }
      }
      return;
    }
    if (predecessor->IsCrossModuleAllReduce()) {
      for (HloInstruction* instr :
           metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
        if (unique.insert(instr).second) {
          predecessors.push_back(instr);
        }
      }
      return;
    }
    unique.insert(predecessor);
    predecessors.push_back(predecessor);
  };
  // If the given instruction is a companion instruction, we need to find the
  // predecessors of all of its companion instructions. If the instruction is an
  // all-reduce, we need to find the predecessors of all the peer all-reduce
  // instructions.
  std::vector<HloInstruction*> instruction_group;
  if (metadata_.IsCompanionInstruction(instruction)) {
    for (HloInstruction* companion : metadata_.Companions(instruction)) {
      instruction_group.push_back(companion);
    }
  } else if (instruction->IsCrossModuleAllReduce()) {
    instruction_group =
        metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
  } else {
    instruction_group.push_back(instruction);
  }

  for (HloInstruction* hlo : instruction_group) {
    for (HloInstruction* operand : hlo->operands()) {
      add_unique_predecessor(operand);
    }
    for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
      add_unique_predecessor(control_predecessor);
    }
  }
  if (instruction->opcode() == HloOpcode::kRecvDone &&
      !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
    // Send is a remote predecessor of RecvDone.
    HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
    add_unique_predecessor(send);
  }
  if (instruction->opcode() == HloOpcode::kSend &&
      !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
    // Recv is a remote predecessor of Send.
    HloInstruction* recv_done =
        metadata_.GetChannel(instruction->channel_id()).recv_done;
    CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
    CHECK_EQ(recv_done->operand_count(), 1);
    HloInstruction* recv = recv_done->mutable_operand(0);
    add_unique_predecessor(recv);
  }
  return predecessors;
}

std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
    HloInstruction* instruction) {
  std::vector<HloInstruction*>
      successors;  // Use a vector to avoid non-determinism.
  absl::flat_hash_set<HloInstruction*> unique;

  // Adds to the unique successors list; if the successor is a companion
  // instruction, also add companion instructions; if the successor is a
  // cross-module all-reduce, also add the all-reduce instructions in the same
  // group.
  auto add_unique_successor = [&](HloInstruction* successor) {
    if (unique.find(successor) != unique.end()) {
      return;
    }
    if (metadata_.IsCompanionInstruction(successor)) {
      for (HloInstruction* instr : metadata_.Companions(successor)) {
        if (unique.insert(instr).second) {
          successors.push_back(instr);
        }
      }
      return;
    }
    if (successor->IsCrossModuleAllReduce()) {
      for (HloInstruction* instr :
           metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
        if (unique.insert(instr).second) {
          successors.push_back(instr);
        }
      }
      return;
    }
    unique.insert(successor);
    successors.push_back(successor);
  };

  // If the given instruction is a companion instruction, we need to find the
  // successors of all of its companion instructions. If the instruction is an
  // all-reduce, we need to find the successors of all its peer all-reduce
  // instructions.
  std::vector<HloInstruction*> instruction_group;
  if (metadata_.IsCompanionInstruction(instruction)) {
    for (HloInstruction* companion : metadata_.Companions(instruction)) {
      instruction_group.push_back(companion);
    }
  } else if (instruction->IsCrossModuleAllReduce()) {
    instruction_group =
        metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
  } else {
    instruction_group.push_back(instruction);
  }

  for (HloInstruction* hlo : instruction_group) {
    for (HloInstruction* user : hlo->users()) {
      add_unique_successor(user);
    }
    for (HloInstruction* control_successor : hlo->control_successors()) {
      add_unique_successor(control_successor);
    }
  }
  if (instruction->opcode() == HloOpcode::kRecv &&
      !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
    // Send is a remote successor of Recv.
    const HloInstruction* recv_done = instruction->users().front();
    CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
    HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
    add_unique_successor(send);
  }
  if (instruction->opcode() == HloOpcode::kSend &&
      !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
    // RecvDone is a remote successor of Send.
    HloInstruction* recv_done =
        metadata_.GetChannel(instruction->channel_id()).recv_done;
    add_unique_successor(recv_done);
  }
  return successors;
}

std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
    absl::Span<HloComputation* const> computations) {
  std::vector<HloInstruction*> roots;
  for (HloComputation* computation : computations) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (GlobalSuccessors(instruction).empty()) {
        roots.push_back(instruction);
      }
    }
  }
  return roots;
}

Status HloModuleGroupUtil::VisitTopologicalOrder(
    VisitStates* visit_state, const VisitFunction& visit_function,
    HloInstruction* root) {
  // Stack of HLO instructions visited in DFS order.
  std::stack<HloInstruction*> stack;
  stack.push(root);

  while (!stack.empty()) {
    HloInstruction* hlo = stack.top();

    // Find the instruction group of the currently visited instruction. The
    // instruction group represents all companion instructions of the current
    // instruction, or all the all-reduce instructions that belong to the same
    // group, or are considered to be a single entity for the purpose of the
    // traversal (i.e., they must always be in the same visit state).
    std::vector<HloInstruction*> instruction_group;
    if (metadata_.IsCompanionInstruction(hlo)) {
      for (HloInstruction* companion : metadata_.Companions(hlo)) {
        instruction_group.push_back(companion);
      }
    } else if (hlo->IsCrossModuleAllReduce()) {
      instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
    } else {
      instruction_group.push_back(hlo);
    }

    if ((*visit_state)[hlo] == VisitState::kVisited) {
      // All instructions in the group must be in the same state.
      for (HloInstruction* instruction : instruction_group) {
        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited);
      }
      stack.pop();
      continue;
    }

    if ((*visit_state)[hlo] == VisitState::kVisiting) {
      TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group));

      // Set the visit state of all instructions in the group to kVisited.
      for (HloInstruction* instruction : instruction_group) {
        TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting);
        (*visit_state)[instruction] = VisitState::kVisited;
      }
      stack.pop();
      continue;
    }

    // Set the visit state of all instructions in the group to kVisiting.
    for (HloInstruction* instruction : instruction_group) {
      TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited)
          << instruction->ToString();
      (*visit_state)[instruction] = VisitState::kVisiting;
    }

    // For each instruction in the group, visit its predecessors (operands,
    // control predecessors and remote predecessors).
    for (HloInstruction* instruction : instruction_group) {
      for (HloInstruction* predecessor : GlobalPredecessors(instruction)) {
        // Visiting a node that is already being visited implies that there is
        // a cycle. Generate an error with the list of instructions in the
        // cycle.
        if ((*visit_state)[predecessor] == VisitState::kVisiting) {
          string cyclic_instructions;
          for (const auto& state : *visit_state) {
            if (state.second == VisitState::kVisiting) {
              absl::StrAppend(&cyclic_instructions, state.first->ToString(),
                              "\n");
            }
          }
          // TODO(b/64305524): Improve the error message to print out the
          // instructions in a deterministic order that forms the cycle.
          return FailedPrecondition(
              "Cross-computation cycle detected via communicating nodes. The "
              "cycle contains the node %s. The cycle is found among the "
              "following nodes. Note that the order of the nodes is arbitrary "
              "and that the list may include nodes that are not part of the "
              "cycle.\n%s",
              predecessor->ToString(), cyclic_instructions);
        }
        stack.push(predecessor);
      }
    }
  }

  return Status::OK();
}

Status HloModuleGroupUtil::VerifyComputations(
    absl::Span<HloComputation* const> computations) {
  auto visit_function =
      [&](HloInstruction* instruction,
          const std::vector<HloInstruction*>& instruction_group) {
        return Status::OK();
      };
  int64 instructions_count = 0;
  VisitStates visit_states;
  for (HloComputation* computation : computations) {
    // Visit all instructions, and not just from the root instruction of the
    // computation. This allows us to detect dead cycles (i.e., cycles that
    // are not reachable from the root) or to enforce an order for the
    // communication instructions that are not reachable from any roots.
    for (HloInstruction* instruction : computation->instructions()) {
      TF_RETURN_IF_ERROR(
          VisitTopologicalOrder(&visit_states, visit_function, instruction));
    }
    instructions_count += computation->instruction_count();
  }

  // Check if all instructions are visited and are in the visited state.
  TF_RET_CHECK(visit_states.size() == instructions_count);
  for (auto& state : visit_states) {
    TF_RET_CHECK(state.second == VisitState::kVisited);
  }

  return Status::OK();
}

StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
    absl::Span<HloComputation* const> computations) {
  std::vector<HloInstruction*> post_order;
  auto visit_function =
      [&](HloInstruction* instruction,
          const std::vector<HloInstruction*>& instruction_group) {
        post_order.insert(post_order.end(), instruction_group.begin(),
                          instruction_group.end());
        return Status::OK();
      };
  HloModuleGroupUtil::VisitStates visit_states;
  for (HloInstruction* root : RootInstructions(computations)) {
    TF_RETURN_IF_ERROR(
        VisitTopologicalOrder(&visit_states, visit_function, root));
  }
  auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
  for (HloInstruction* hlo : post_order) {
    reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
  }
  return std::move(reachability);
}

void HloModuleGroupUtil::UpdateReachabilityThroughInstruction(
    HloInstruction* instruction, HloReachabilityMap* reachability_map) {
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);

  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();
    if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item),
                                                 item)) {
      for (HloInstruction* successor : GlobalSuccessors(item)) {
        worklist.push(successor);
      }
    }
  }
}

}  // namespace xla