aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/multi_output_fusion.cc
blob: 79b5a442aa0ecd0f67ffe4dad50465627d8913fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/multi_output_fusion.h"

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Runs sibling multi-output fusion over every non-fusion computation of
// `module`.
//
// For each computation this:
//  1. resets the per-computation state and recomputes reachability,
//  2. assigns each instruction a candidate id in post order,
//  3. for every fusible instruction, collects the fusible users of its
//     profitable operands (its "siblings") that pass is_connected(), the
//     fusion-ordering filter and LegalToFuse(), and
//  4. records every sibling pair with positive GetProfit() in both nodes'
//     `fusibles` lists and on worklist_,
// and then calls Perform() to carry out the fusions.
//
// Returns true if Perform() reported a change for any computation.
StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
  bool any_change = false;

  for (auto* comp : module->MakeNonfusionComputations()) {
    // Per-computation state reset.
    computation_ = comp;
    RecomputeReachability();
    candidates_.clear();
    candidates_index_.clear();
    all_fusion_candidates_.clear();

    // Candidate ids follow post order: candidates_index_ maps each
    // instruction to its position in candidates_.
    int64 next_id = 0;
    for (HloInstruction* hlo : computation_->MakeInstructionPostOrder()) {
      candidates_.emplace_back(hlo);
      InsertOrDie(&candidates_index_, hlo, next_id++);
    }

    // Create the initial candidate list for each node. candidates_ is not
    // resized inside this loop, so `node` stays valid throughout.
    for (auto& node : candidates_) {
      HloInstruction* const instruction = node.hlo;
      const int64 instruction_id = get_candidate_id(instruction);
      if (!IsFusible(instruction)) {
        continue;
      }
      all_fusion_candidates_.push_back(instruction);

      // `siblings` keeps discovery order (for determinism); `seen` only
      // deduplicates.
      std::vector<HloInstruction*> siblings;
      tensorflow::gtl::FlatSet<HloInstruction*> seen;
      VLOG(10) << "Looking at instruction: " << instruction->name();
      for (HloInstruction* operand : instruction->operands()) {
        // Operands that cannot yield savings are not worth scanning.
        if (!IsProfitableOperand(operand)) {
          VLOG(10) << "Operand not profitable: " << operand->name();
          continue;
        }
        VLOG(10) << "Operand profitable: " << operand->name();
        for (HloInstruction* user : operand->users()) {
          VLOG(10) << "User: " << user->name();
          if (user == instruction || !IsFusible(user)) {
            VLOG(10) << "User is not fusible, or is the instruction itself: "
                     << user->name();
            continue;
          }
          const int64 user_id = get_candidate_id(user);
          if (is_connected(instruction, user)) {
            VLOG(10) << "User is connected: " << user->name();
            continue;
          }
          // A fusion user that comes later in post order is skipped here.
          if (instruction_id < user_id &&
              user->opcode() == HloOpcode::kFusion) {
            VLOG(10) << "User ID for user: " << user->name() << " is "
                     << user_id << " which is higher than " << instruction_id;
            continue;
          }
          if (!LegalToFuse(instruction, user)) {
            VLOG(10) << "User not legal to fuse: " << user->name();
            continue;
          }
          if (seen.insert(user).second) {
            VLOG(10) << "User added to candidate list: " << user->name();
            siblings.push_back(user);
          }
        }
      }

      // Walk the ordered vector rather than the hash set to stay
      // deterministic.
      for (HloInstruction* sibling : siblings) {
        const int64 profit = GetProfit(instruction, sibling);
        if (profit <= 0) {
          continue;
        }
        FusionCandidate& sibling_node =
            candidates_[get_candidate_id(sibling)];
        node.fusibles.emplace_back(sibling, profit);
        sibling_node.fusibles.emplace_back(instruction, profit);
        worklist_.emplace(instruction, sibling, profit);
      }
    }

    // Perform() must run for every computation; do not short-circuit it.
    any_change |= Perform();
  }
  return any_change;
}

// Fuses `instr1` and `instr2` into a single multi-output fusion node and
// returns the node that remains in the computation.
//
// Normally the returned node is one of the two inputs; the other has been
// absorbed into it. NOTE(review): when the surviving node is a fusion whose
// root is a bitcast, it is first wrapped in a new nested kLoop fusion and
// replaced in the parent computation via ReplaceInstruction. The returned
// pointer is then neither `instr1` nor `instr2`. Perform() decides which
// input was absorbed with `ret == instr1`, so confirm that this case is
// handled correctly by callers.
HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
                                        HloInstruction* instr2) {
  HloInstruction* remaining = instr1;
  HloInstruction* fused = instr2;
  // Make sure that if only one of the instructions is a fusion, or if only one
  // of the instructions is a multi-output fusion, it's what will be fused into.
  //
  // An invariant is that no bitcast nodes will show up in the middle of a
  // fusion node. This invariant must hold in order for us to lower it. Given
  // that, we require that during multi-output fusion, a fusion node ending
  // with a bitcast preserve its structure as a nested fusion instead of being
  // merged and flattened.
  if (fused->opcode() == HloOpcode::kFusion &&
      fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
    std::swap(remaining, fused);
  }
  if (fused->IsMultiOutputFusion()) {
    std::swap(remaining, fused);
  }

  if (fused->opcode() == HloOpcode::kFusion &&
      fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
    // Both sides are fusions and `fused` is not bitcast-rooted: merge its
    // fused computation into `remaining`.
    remaining->MergeFusionInstructionIntoMultiOutput(fused);
  } else {
    // `fused` is either a plain instruction or a bitcast-rooted fusion; it is
    // fused into `remaining` as a single instruction.
    if (remaining->opcode() == HloOpcode::kFusion &&
        remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) {
      auto parent_computation = remaining->parent();
      // Create a nested fusion node, so the bitcast root of `remaining` does
      // not end up in the middle of the resulting fusion.
      auto remaining_nested_fused =
          parent_computation->AddInstruction(HloInstruction::CreateFusion(
              remaining->shape(), HloInstruction::FusionKind::kLoop,
              remaining));
      TF_CHECK_OK(parent_computation->ReplaceInstruction(
          remaining, remaining_nested_fused));
      remaining = remaining_nested_fused;
    }
    remaining->FuseInstructionIntoMultiOutput(fused);
  }

  return remaining;
}

// Returns whether `instr`, viewed as an operand shared by several users, is
// worth considering as a source of sibling multi-output fusion.
//
// Rejected operands:
//  - constants whose shape is effectively scalar. Such constants are assumed
//    not to cause memory reads worth saving. Only effective-scalar constants
//    are filtered, not every kConstant.
//  - operands with fewer than two users. A single user has no siblings, and
//    producer/consumer fusion is left to the instruction_fusion pass.
bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
  const bool is_scalar_constant =
      instr->opcode() == HloOpcode::kConstant &&
      ShapeUtil::IsEffectiveScalar(instr->shape());
  return !is_scalar_constant && instr->user_count() >= 2;
}

// Refreshes the pass's bookkeeping after `instr1` and `instr2` were fused.
//
// Whichever input is marked is_fused() is treated as the absorbed node
// (`fused`); the other is the surviving node (`fusion`). The steps are:
//  1. Register users of `fusion` that are not yet in candidates_.
//  2. Propagate reachability across all_fusion_candidates_, skipping
//     instructions already marked fused.
//  3. Prune `fusion`'s fusibles that are fused or now connected to it, and
//     recompute profits for the rest. Entries whose profit increased are
//     re-queued.
//  4. Move `fused`'s still-valid fusibles onto `fusion`, queue them, and clear
//     `fused`'s list.
void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
  HloInstruction* fusion = instr1;
  HloInstruction* fused = instr2;
  if (is_fused(instr1)) {
    fusion = instr2;
    fused = instr1;
  }

  // Insert the newly created instructions (if any) into candidates_. This runs
  // before any reference into candidates_ is taken below, because growing the
  // container may invalidate such references.
  for (HloInstruction* use : fusion->users()) {
    if (candidates_index_.find(use) == candidates_index_.end()) {
      const int64 index = candidates_.size();
      candidates_.emplace_back(use);
      InsertOrDie(&candidates_index_, use, index);
    }
  }
  FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
  FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];

  // Update the reachability graph.
  UpdateReachability(fusion, fused, all_fusion_candidates_,
                     [this](HloInstruction* instr) { return is_fused(instr); });

  // Update the fusible list for fusion. new_fusibles keeps track of the new or
  // changed entries that must be (re)queued on worklist_.
  std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
  tensorflow::gtl::FlatSet<HloInstruction*> in_list;
  auto iter = fusion_node.fusibles.begin();
  while (iter != fusion_node.fusibles.end()) {
    HloInstruction* instr = iter->first;
    if (is_fused(instr) || is_connected(fusion, instr)) {
      iter = fusion_node.fusibles.erase(iter);
      continue;
    }
    in_list.insert(instr);
    int64 profit = GetProfit(instr, fusion);
    // Only entries whose profit went up are re-queued.
    if (profit > iter->second) {
      iter->second = profit;
      new_fusibles.emplace_back(instr, profit);
    }
    ++iter;
  }

  // fused_node has been fused into fusion_node. Move the fusion candidates
  // (fusibles) of fused_node over to fusion_node, skipping those that are no
  // longer valid or are already in fusion_node's list.
  for (const auto& entry : fused_node.fusibles) {
    HloInstruction* instr = entry.first;
    if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) {
      continue;
    }
    if (in_list.count(instr) > 0) {
      continue;
    }
    int64 profit = GetProfit(instr, fusion);
    fusion_node.fusibles.emplace_back(instr, profit);
    new_fusibles.emplace_back(instr, profit);
  }
  fused_node.fusibles.clear();

  // Update the worklist_.
  for (const auto& entry : new_fusibles) {
    worklist_.emplace(fusion, entry.first, entry.second);
  }
}

// Returns whether `instr1` and `instr2` may legally be combined by this pass.
//
// All of the following must hold:
//  - they are distinct instructions, and `instr1` is a kFusion node;
//  - both have at least one user;
//  - neither is a multi-output fusion with a non-get-tuple-element user;
//  - they are not connected (per is_connected());
//  - ShapesCompatibleForFusion() accepts the pair.
bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
                                    HloInstruction* instr2) {
  if (instr1 == instr2 || instr1->opcode() != HloOpcode::kFusion) {
    return false;
  }

  // Fusing nodes without users makes no sense, and the rest of the
  // implementation does not support it.
  if (instr1->user_count() == 0 || instr2->user_count() == 0) {
    return false;
  }

  // The transformation assumes every user of a multi-output fusion is a
  // get-tuple-element; bail out if that does not hold.
  const auto has_non_gte_user = [](HloInstruction* instr) {
    if (instr->IsMultiOutputFusion()) {
      for (HloInstruction* user : instr->users()) {
        if (user->opcode() != HloOpcode::kGetTupleElement) {
          return true;
        }
      }
    }
    return false;
  };
  if (has_non_gte_user(instr1) || has_non_gte_user(instr2)) {
    return false;
  }

  return !is_connected(instr1, instr2) &&
         ShapesCompatibleForFusion(instr1, instr2);
}

void MultiOutputFusion::RecomputeReachability() {
  reachability_ = computation_->ComputeReachability();
}

// After `instr1` and `instr2` have been merged, propagates reachability to
// each instruction in `instrs_to_update`. Any target reachable from one of
// the two has the other's reachability merged into its own via
// FastSetReachabilityToUnion. Targets for which `skip` (if set) returns true,
// or that are already reachable from both, are left unchanged.
void MultiOutputFusion::UpdateReachability(
    HloInstruction* instr1, HloInstruction* instr2,
    tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
    const std::function<bool(HloInstruction*)>& skip) {
  for (HloInstruction* target : instrs_to_update) {
    if (skip && skip(target)) {
      continue;
    }
    const bool from_instr1 = reachability_->IsReachable(instr1, target);
    const bool from_instr2 = reachability_->IsReachable(instr2, target);
    if (from_instr1 && from_instr2) {
      // Already reachable from both; nothing to merge.
      continue;
    }
    if (from_instr2) {
      reachability_->FastSetReachabilityToUnion({target, instr1}, target);
    }
    // Re-queried on purpose instead of reusing `from_instr1`: the union above
    // may have changed the answer.
    if (reachability_->IsReachable(instr1, target)) {
      reachability_->FastSetReachabilityToUnion({target, instr2}, target);
    }
  }
}

// Drains worklist_, fusing candidate pairs taken from worklist_.top(). The
// ordering is presumably by score; see ToBeFused's comparator in the header.
//
// A pair is skipped if either side was already absorbed by an earlier fusion,
// or if LegalToFuse() rejects it. Each performed fusion consumes one unit of
// fuel_. When fuel_ reaches zero the loop stops, leaving any remaining
// entries on worklist_. Afterwards DoProducerConsumerMultiOutputFusion() runs.
//
// Returns true if any fusion was performed by either phase.
bool MultiOutputFusion::Perform() {
  bool changed = false;
  // Pick the top candidate from queue and try to merge.
  while (!worklist_.empty()) {
    if (fuel_ <= 0) {
      VLOG(2) << "No fusing: run out of fuel.";
      break;
    }
    ToBeFused candidate = worklist_.top();
    worklist_.pop();

    HloInstruction* instr1 = candidate.instr1;
    HloInstruction* instr2 = candidate.instr2;

    // Stale entry: one side was already fused into something else.
    if (is_fused(instr1) || is_fused(instr2)) {
      continue;
    }

    VLOG(1) << "Considering candidate profit_score=" << candidate.score
            << "\n\t\tinstr1 = " << instr1->ToString()
            << "\n\t\tinstr2 = " << instr2->ToString();

    if (LegalToFuse(instr1, instr2)) {
      VLOG(1) << "Fuse!";
      VLOG(2) << "Before multi_output_fusion:";
      VLOG(2) << "instr1: " << instr1->ToString();
      // LegalToFuse() guarantees instr1 is a kFusion node.
      VLOG(2) << "\n"
              << instr1->fused_instructions_computation()->ToString(
                     HloPrintOptions().set_indent_amount(1));
      VLOG(2) << "instr2: " << instr2->ToString();
      if (instr2->opcode() == HloOpcode::kFusion) {
        VLOG(2) << "\n"
                << instr2->fused_instructions_computation()->ToString(
                       HloPrintOptions().set_indent_amount(1));
      }
      HloInstruction* ret = Fuse(instr1, instr2);
      // NOTE(review): Fuse() can return a newly created nested fusion, which
      // is neither instr1 nor instr2. In that case instr1 is marked fused
      // here; confirm that this is correct.
      set_is_fused(ret == instr1 ? instr2 : instr1);
      Update(instr1, instr2);
      changed = true;
      VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
              << ret->fused_instructions_computation()->ToString(
                     HloPrintOptions().set_indent_amount(1));
      --fuel_;
    }
  }
  if (DoProducerConsumerMultiOutputFusion()) {
    changed = true;
  }
  return changed;
}

// Producer/consumer multi-output fusion hook, called at the end of Perform().
// This implementation performs no fusion and always reports "no change".
// NOTE(review): presumably overridden by backend-specific subclasses; confirm
// against the header.
bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
  constexpr bool kChanged = false;
  return kChanged;
}
}  // namespace xla