aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering.cc
blob: 23d41d91d6969ddf9062507e926ae39c1e1315d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_ordering.h"

#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

bool HloOrdering::ExecutesBefore(const HloInstruction* a,
                                 const HloInstruction* b) const {
  // 'a' and 'b' may be in different computations. In this case, find the
  // callgraph ancestor instructions which call (potentially transitively) the
  // computations containing 'a' and 'b' and use these ancestor instructions to
  // compare order.
  const HloInstruction* a_ancestor;
  const HloInstruction* b_ancestor;
  std::tie(a_ancestor, b_ancestor) =
      call_graph_->NearestAncestorsInSameComputation(
          const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));

  if (a_ancestor == nullptr) {
    // Ancestors in a common computation could not be found so consider the
    // instructions 'a' and 'b' to be unordered.
    return false;
  }
  // a_ancestor and b_ancestor must be either both null or both non-null.
  CHECK_NE(b_ancestor, nullptr);
  CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());

  // If the common ancestor is a while instruction there is an additional
  // ordering criteria which may apply. The condition computation is considered
  // to execute before the body computation so if 'a' is in the condition and
  // 'b' is in the body, then 'a' executes before 'b'.
  if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
    const HloComputation* body = a_ancestor->while_body();
    const HloComputation* condition = a_ancestor->while_condition();
    if (call_graph_->InstructionIsNestedIn(a, condition) &&
        call_graph_->InstructionIsNestedIn(b, body)) {
      return true;
    }
  }

  // If the common ancestor is a conditional instruction, even though the true
  // and false computations are not really ordered per-se, we define the true
  // computation to be ordered before the false one.
  // This ensures that buffers can still be shared among the two computations
  // as they will forcibly have disjoint liveness.
  if (a_ancestor == b_ancestor &&
      a_ancestor->opcode() == HloOpcode::kConditional) {
    const HloComputation* true_computation = a_ancestor->true_computation();
    const HloComputation* false_computation = a_ancestor->false_computation();
    if (call_graph_->InstructionIsNestedIn(a, true_computation) &&
        call_graph_->InstructionIsNestedIn(b, false_computation)) {
      return true;
    }
    // If 'b' is the conditional ancestor, and 'a' is within the true or false
    // computations, 'a' executes before 'b'.
    if (b == a_ancestor &&
        (call_graph_->InstructionIsNestedIn(a, true_computation) ||
         call_graph_->InstructionIsNestedIn(a, false_computation))) {
      return true;
    }
  }

  return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
}

bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
  // Entry parameter should always be defined before other instructions.
  const HloModule* module = b.defining_instruction()->parent()->parent();
  if (b.defining_instruction()->parent() == module->entry_computation() &&
      b.defining_instruction()->opcode() == HloOpcode::kParameter) {
    return false;
  }

  if (a.defining_instruction()->parent() == module->entry_computation() &&
      a.defining_instruction()->opcode() == HloOpcode::kParameter) {
    return true;
  }

  // Phi values require special handling. Because XLA does not have a phi
  // instruction, the definition instruction of the phis values are
  // placeholders: either the subcomputation parameter (body or condition) or
  // the while instruction. However, the program point where these values are
  // logically defined does not necessarily coincide exactly with program point
  // of these place-holder instructions. So we explicitly define the following
  // order for phi values:
  //
  //   body/condition parameter phi:
  //     Defined before all values defined in its computation excepting other
  //     phis.
  //
  //   while phi:
  //     defined after all values defined in the condition or body.
  //
  auto is_body_or_condition_phi = [](const HloValue& v) {
    return v.is_phi() &&
           v.defining_instruction()->opcode() == HloOpcode::kParameter;
  };
  if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(b.defining_instruction(),
                                         a.defining_instruction()->parent())) {
    return true;
  }
  if (is_body_or_condition_phi(b) &&
      call_graph_->InstructionIsNestedIn(a.defining_instruction(),
                                         b.defining_instruction()->parent())) {
    return false;
  }

  // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
  // executes before 'b'.
  if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
      (call_graph_->InstructionIsNestedIn(
           a.defining_instruction(), b.defining_instruction()->while_body()) ||
       call_graph_->InstructionIsNestedIn(
           a.defining_instruction(),
           b.defining_instruction()->while_condition()))) {
    return true;
  }
  // If 'b' is a conditional phi and 'a' is in the true or false computation,
  // then 'a' executes before 'b'.
  if (b.is_phi() &&
      b.defining_instruction()->opcode() == HloOpcode::kConditional &&
      (call_graph_->InstructionIsNestedIn(
           a.defining_instruction(),
           b.defining_instruction()->true_computation()) ||
       call_graph_->InstructionIsNestedIn(
           a.defining_instruction(),
           b.defining_instruction()->false_computation()))) {
    return true;
  }
  return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
}

/* static */
bool HloOrdering::UseIsBeforeValueDefinition(
    const HloUse& use, const HloValue& value,
    const HloDataflowAnalysis& dataflow) const {
  VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
          << ", value=" << value.ToShortString() << ")";
  if (ExecutesBefore(use.instruction, value.defining_instruction())) {
    VLOG(4) << "  use instruction executes before value-defining instruction";
    return true;
  }

  // If the use is at the instruction where the value is defined, then the use
  // is before the def if the instruction allows buffer sharing (in place
  // computation).
  if (use.instruction == value.defining_instruction() &&
      dataflow.CanShareOperandBufferWithUser(
          use.instruction->mutable_operand(use.operand_number),
          use.operand_index, value.defining_instruction(),
          value.defining_index())) {
    VLOG(4) << "  use is value def, and instruction can share use buffer";
    return true;
  }

  // The use at a while is an input to a phi, and logically occurs before values
  // are defined in the body or condition computations.
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    const HloInstruction* xla_while = use.instruction;
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           xla_while->while_body()) ||
        call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           xla_while->while_condition())) {
      VLOG(4) << "  use is while " << use.instruction->name()
              << " and def is in condition or body";
      return true;
    }
  }

  // Similarly if the value is defined at a while, it logically occurs after any
  // uses in the body or condition computations.
  if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
    CHECK(value.is_phi());
    const HloInstruction* xla_while = value.defining_instruction();
    if (call_graph_->InstructionIsNestedIn(use.instruction,
                                           xla_while->while_body()) ||
        call_graph_->InstructionIsNestedIn(use.instruction,
                                           xla_while->while_condition())) {
      VLOG(4) << "  value is while " << value.defining_instruction()->name()
              << " and use is in condition or body";
      return true;
    }
  }

  // The use at a call occurs before values that are defined in the called
  // computation.
  if (use.instruction->opcode() == HloOpcode::kCall) {
    const HloInstruction* call = use.instruction;
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           call->to_apply())) {
      VLOG(4) << "  use is call " << use.instruction->name()
              << " and def is in called computation";
      return true;
    }
  }

  if (use.instruction->opcode() == HloOpcode::kConditional) {
    const HloInstruction* conditional = use.instruction;
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           conditional->true_computation())) {
      VLOG(4) << "  use is conditional " << use.instruction->name()
              << " and def is in TRUE computation";
      return true;
    }
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
                                           conditional->false_computation())) {
      VLOG(4) << "  use is conditional " << use.instruction->name()
              << " and def is in FALSE computation";
      return true;
    }
    if (value.defining_instruction() == use.instruction) {
      VLOG(4) << "  use is conditional " << use << " and def is "
              << value.ToShortString();
      return true;
    }
  }

  VLOG(4) << "  use is not before value";
  return false;
}

bool HloOrdering::LiveRangeStrictlyBefore(
    const HloValue& a, const HloValue& b,
    const HloDataflowAnalysis& dataflow) const {
  VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
          << ", b = " << b.ToShortString() << ")";
  if (!IsDefinedBefore(a, b)) {
    VLOG(4) << a << " not defined before " << b;
    return false;
  }

  if (a.live_out_of_module()) {
    VLOG(4) << a << " is live out of module and defined before " << b;
    return false;
  }

  // All uses of 'a' must be before 'b' is defined.
  for (const HloUse& use : a.uses()) {
    if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
                                         use.instruction)) {
      continue;
    }
    if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
      VLOG(4) << "use of " << a << " (" << use << ") not before " << b
              << " is defined";
      return false;
    }
  }

  if (a.instruction()->parent() == b.instruction()->parent()) {
    for (const HloPosition& position : a.positions()) {
      if (position.instruction ==
          a.instruction()->parent()->root_instruction()) {
        VLOG(4) << a << " is live out of computation and defined before " << b
                << " which is in same computation";
        return false;
      }
    }
  }

  return true;
}

// Returns true if 'a' and 'b' may interfere, i.e. if neither value's live
// range is strictly before the other's.
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
                               const HloDataflowAnalysis& dataflow) const {
  // Live ranges are disjoint when one of them ends strictly before the other
  // begins, in either direction.
  const bool live_ranges_disjoint = LiveRangeStrictlyBefore(a, b, dataflow) ||
                                    LiveRangeStrictlyBefore(b, a, dataflow);
  return !live_ranges_disjoint;
}

// Forwards 'module' to the HloOrdering base constructor. predecessors_ is not
// populated here.
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
    : HloOrdering(module) {}

// Returns true if 'a' is a strict predecessor of 'b' according to the
// reachability information stored for their (shared) computation.
bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
    const HloInstruction* a, const HloInstruction* b) const {
  CHECK_EQ(a->parent(), b->parent());

  // The predecessor relation is strict: an instruction never executes before
  // itself.
  if (a == b) {
    return false;
  }
  return predecessors_.at(a->parent())->IsReachable(a, b);
}

// Renders the ordering as text: 'name' on the first line, then for each
// non-fusion computation its name and, per instruction (in post order), the
// instructions of that computation from which it is reachable. Lines are
// joined with '\n'.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines(1, name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrFormat("computation %s:", computation->name()));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto* instruction : post_order) {
      lines.push_back(
          absl::StrFormat("  %s predecessors:", instruction->name()));
      // The reachability map is the same for every candidate below, so look
      // it up once per instruction.
      const auto& reachability = predecessors_.at(computation);
      for (auto* candidate : post_order) {
        if (reachability->IsReachable(candidate, instruction)) {
          lines.push_back(absl::StrFormat("    %s", candidate->name()));
        }
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}

DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
    : PredecessorHloOrdering(module) {
  // The ordering is derived purely from dependencies: record, for every
  // non-fusion computation, the reachability between its instructions.
  // ExecutesBefore then returns true iff the HLO computation graph contains a
  // path from 'a' to 'b'.
  const auto& nonfusion_computations = module->MakeNonfusionComputations();
  for (auto* nonfusion_computation : nonfusion_computations) {
    predecessors_.emplace(nonfusion_computation,
                          nonfusion_computation->ComputeReachability());
  }
}

// Returns the text produced by ToStringHelper, using "DependencyHloOrdering"
// as the name passed to it.
string DependencyHloOrdering::ToString() const {
  return ToStringHelper("DependencyHloOrdering");
}

// Constructs the ordering from 'schedule', initializing schedule_ from the
// const reference, then calls Initialize().
SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
    : HloOrdering(schedule.module()), schedule_(schedule) {
  Initialize();
}

// Constructs the ordering from 'schedule', moving it into schedule_, then
// calls Initialize(). The base class is initialized from schedule.module()
// before the move takes place.
SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
    : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
  Initialize();
}

// Fills order_position_ with the position of every scheduled instruction
// inside its computation's sequence. Positions restart at 0 for each
// sequence in schedule_.
void SequentialHloOrdering::Initialize() {
  // Create a map from instruction to its order position.
  TF_DCHECK_OK(schedule_.Verify());
  for (const auto& computation_sequence : schedule_.sequences()) {
    const std::vector<const HloInstruction*>& order =
        computation_sequence.second.instructions();
    // Positions are stored as int (the loop index is what gets inserted).
    // Convert the size once so the loop condition does not compare a signed
    // index against an unsigned size.
    const int num_instructions = static_cast<int>(order.size());
    for (int i = 0; i < num_instructions; ++i) {
      // NOTE(review): InsertOrDie presumably aborts if the instruction is
      // already present in the map -- confirm against its definition.
      InsertOrDie(&order_position_, order[i], i);
    }
  }
}

// Returns true if 'a' has a smaller recorded position than 'b'. Both
// instructions must belong to the same computation (CHECKed). If either one
// has no recorded position, the two are unordered and false is returned.
bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
    const HloInstruction* a, const HloInstruction* b) const {
  CHECK_EQ(a->parent(), b->parent());
  // Look each instruction up once and reuse the iterator for the comparison,
  // instead of a count() followed by a second lookup through at().
  const auto a_position = order_position_.find(a);
  if (a_position == order_position_.end()) {
    return false;
  }
  const auto b_position = order_position_.find(b);
  if (b_position == order_position_.end()) {
    return false;
  }
  return a_position->second < b_position->second;
}

// Returns a pointer to the instruction sequence scheduled for 'computation',
// or nullptr if the schedule holds no sequence for it. The pointer refers to
// data held by schedule_.
const std::vector<const HloInstruction*>*
SequentialHloOrdering::SequentialOrder(
    const HloComputation& computation) const {
  if (!schedule_.is_computation_scheduled(&computation)) {
    return nullptr;
  }
  return &schedule_.sequence(&computation).instructions();
}

// Returns the line "SequentialHloOrdering" followed by the string form of the
// underlying schedule.
string SequentialHloOrdering::ToString() const {
  return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
}

}  // namespace xla