aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/multi_output_fusion.h
blob: 9508ab2ed1d38ec40983d8892ec8875b848fb21b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_

#include <functional>
#include <list>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// This class implements the fusing of sibling fusion instructions that share
// common operands.
// It constructs the following associated data structures.
//  (1) candidates_: stores the instruction and the set of instructions it can
//      fuse to.
//  (2) candidates_index_: maps instruction to id.
//  (3) reachability_: reachability map in this computation.
//  (4) all_fusion_candidates_: the vector of candidate instructions.
//  (5) worklist_: a priority queue that contains pairs of instructions to be
//      fused and their fusion profit scores.
//
//  Function Perform() applies the optimization. It picks the most profitable
//  pair from worklist_, checks whether it is legal to fuse, and fuses the pair.
//  After fusion, it updates the associated structures such as reachability_,
//  candidates_ and worklist_.
//  Note that the reachability map is updated based on the original computation.
//  This works because the reachability is monotonically increasing with
//  instruction fusion.
//
//  Subclasses must implement ShapesCompatibleForFusion(), IsFusible() and
//  GetProfit(). The other virtual hooks (IsProfitableOperand(), LegalToFuse(),
//  Fuse(), DoProducerConsumerMultiOutputFusion()) are non-pure and may be
//  overridden.
class MultiOutputFusion : public HloModulePass {
 public:
  // `fuel` bounds how many changes the pass may make (see fuel_ below).
  // The constructor is explicit: an int64 should never silently convert into
  // a pass object.
  explicit MultiOutputFusion(int64 fuel) : fuel_(fuel) {}

  absl::string_view name() const override { return "multi_output_fusion"; }

  // Run multi-output fusion on the given module. Returns whether the module
  // was changed.
  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Main entry for the optimization. Returns true if the optimization happens.
  bool Perform();

  // Test if instr1 and instr2 have compatible shapes that can be legally
  // fused.
  virtual bool ShapesCompatibleForFusion(HloInstruction* instr1,
                                         HloInstruction* instr2) = 0;

  // Whether the instruction is a candidate for fusion.
  virtual bool IsFusible(HloInstruction* instr) = 0;

  // This function estimates the savings by merging instr1 and instr2 into one
  // multi-output fusion instruction.
  virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0;

  // Whether fusing the instruction can reduce memory reads.
  virtual bool IsProfitableOperand(HloInstruction* instr);

  // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
  virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);

  // Fuse HloInstruction instr1 and instr2 and return the fused instruction.
  // The other instruction is removed from its parent computation.
  virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);

  // Recompute reachability for the current computation.
  void RecomputeReachability();

  // Returns the reachability map for the current computation (non-owning).
  HloReachabilityMap* reachability() const { return reachability_.get(); }

  // Returns the computation for the pass.
  HloComputation* computation() const { return computation_; }

  // Update the reachability map after fusing instr1 and instr2.
  // `skip` is an optional predicate; it defaults to an empty std::function.
  void UpdateReachability(
      HloInstruction* instr1, HloInstruction* instr2,
      absl::Span<HloInstruction* const> instrs_to_update,
      const std::function<bool(HloInstruction*)>& skip = nullptr);

  // Hook for multi-output fusion along producer-consumer edges.
  // Returns whether any instructions were fused.
  //
  // TODO(b/80420762): Perform producer-consumer multi-output fusion in
  // InstructionFusion instead.
  virtual bool DoProducerConsumerMultiOutputFusion();

  // Optimization fuel is a compiler debugging technique that makes an
  // optimization pass stop what it is doing after having made N changes to the
  // program, where N is the fuel. By varying N, this can be used to find the
  // first single change that makes a test fail.
  int64 fuel_;

 private:
  // Update the internal data structures after instr1 and instr2 are fused into
  // one fusion instruction.
  void Update(HloInstruction* instr1, HloInstruction* instr2);

  // Computation for the pass (non-owning). It starts as nullptr so that it is
  // never read uninitialized before the pass assigns it.
  HloComputation* computation_ = nullptr;

  // An internal data structure for each instruction in current computation.
  // When an instruction is removed, member 'hlo' is set to nullptr.
  struct FusionCandidate {
    HloInstruction* hlo;
    // Instructions this candidate may fuse with, each paired with an int64.
    // NOTE(review): the int64 is presumably the fusion profit score; confirm
    // against the .cc.
    std::list<std::pair<HloInstruction*, int64>> fusibles;
    explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {}
  };
  std::vector<FusionCandidate> candidates_;

  // Maps an instruction to its index in candidates_.
  absl::flat_hash_map<HloInstruction*, int> candidates_index_;

  // The reachability map of current computation.
  std::unique_ptr<HloReachabilityMap> reachability_;

  // This stores all the candidate instructions in current computation.
  std::vector<HloInstruction*> all_fusion_candidates_;

  // The pair of candidates to be fused and the profit score.
  struct ToBeFused {
    HloInstruction* instr1;
    HloInstruction* instr2;
    int64 score;
    ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
        : instr1(instr1), instr2(instr2), score(score) {}
    // std::priority_queue is a max-heap under operator<, so the pair with the
    // highest score is at top().
    bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
  };
  std::priority_queue<ToBeFused> worklist_;

  // Index of `instr` in candidates_; dies if `instr` was never registered.
  int64 get_candidate_id(HloInstruction* instr) const {
    return FindOrDie(candidates_index_, instr);
  }

  // True once `instr` has been marked as fused via set_is_fused().
  bool is_fused(HloInstruction* instr) const {
    return candidates_[get_candidate_id(instr)].hlo == nullptr;
  }

  // Marks `instr` as fused by clearing its candidate entry.
  void set_is_fused(HloInstruction* instr) {
    candidates_[get_candidate_id(instr)].hlo = nullptr;
  }

  // Forwards to HloReachabilityMap::IsConnected for the two instructions.
  bool is_connected(HloInstruction* instr1, HloInstruction* instr2) const {
    return reachability_->IsConnected(instr1, instr2);
  }
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_