aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/bfloat16_propagation.h
blob: 5fcaa15c8356107af02e9099874a293d8350c51a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {

// HLO pass which reduces the precision of some HLO instructions to BF16
// according to the backend-specific BFloat16Support rules provided by the
// caller.
//
// This pass can be used to reduce instruction precision without affecting the
// numerical accuracy of the module, i.e., the final output of the module would
// be bitwise identical to that without this pass; this is possible if the
// backend already reduces precision to BF16 on some HLO instructions.
//
// This pass will not modify the signature of a computation, unless it is a
// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
//
// 1) This pass is not guaranteed to respect the passed-in BFloat16Support
// specification in terms of mixed precision, so the backend may not support an
// HLO that has mixed precision produced by this pass. To address this issue,
// run BFloat16Normalization with the same BFloat16Support after this pass.
//
// 2) In general, mixed precision may break the assumptions of some other HLO
// passes even if the specific backend supports the individual HLOs. Such
// assumptions include that there are no HLOs using mixed precision, or that the
// precision of an HLO's output is determined by its inputs. This pass should
// therefore be run at the end of the HLO optimization pipeline, but before
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
//
// Overall structure, as suggested by the grouping of the private members
// below: a forward analysis (parameters to root) selects candidate HLOs, a
// backward pass (root to parameters) looks for opportunities to use BF16, a
// final inconsistency-resolving pass makes aliasing buffers and called
// computations agree, and cleanup functions run after the recorded changes in
// changes_to_bf16_ have been applied.
class BFloat16Propagation : public HloModulePass {
 public:
  // Creates the pass. `bfloat16_support` supplies the backend-specific BF16
  // rules; it is passed as a raw pointer and is not deleted by this class (see
  // bfloat16_support_ below).
  explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

  ~BFloat16Propagation() override = default;

  absl::string_view name() const override { return "bfloat16-propagation"; }

  // Runs the pass on the given module. Returns whether the module was changed
  // (precision reductions were added).
  StatusOr<bool> Run(HloModule* module) override;

 private:
  // ***************************
  // Functions called and state produced by the forward analysis pass (from
  // parameters to root) that determines the candidate HLOs to use BF16 outputs.

  // Determines whether we should consider changing the precision of the given
  // instruction in the forward pass.
  bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);

  // The set of instructions selected in the forward pass as candidates for
  // using BF16 output.
  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;

  // ***************************
  // Functions called and state produced by the backward pass (from root to
  // parameters) that finds opportunities to use BF16.

  // Determines the precision for the given instruction in the
  // opportunity-finding pass.
  // NOTE(review): `skip_parameters` presumably makes this leave kParameter
  // instructions unchanged — confirm against the implementation.
  void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters);

  // Special handling in the opportunity-finding pass for fusion computations.
  //
  // Precondition: fusion->opcode() == kFusion
  void DetermineFusionComputationPrecision(HloInstruction* fusion);

  // Reverts changes to BF16 that will not propagate outside a fusion
  // computation. This avoids the overhead of BF16 casts inside a fusion, which
  // would not save any memory bandwidth.
  //
  // Precondition: fusion->opcode() == kFusion
  void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);

  // Special handling in the opportunity-finding pass for while computations
  // (the condition and body called by `while_hlo`).
  //
  // Precondition: while_hlo->opcode() == kWhile
  void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);

  // The set of HloInstructions that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloInstruction*>
      instructions_visited_in_backward_pass_;

  // The set of HloComputations that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;

  // ***************************
  // Functions called by the final inconsistency-resolving pass.

  // Adjusts the output shapes of HloInstructions such that if two
  // HloInstructions have aliasing buffers in their outputs, they have the
  // same precision.
  void ResolveInconsistencyOfAliasingBuffers(HloModule* module);

  // Resolves inconsistency of aliasing buffers for the given computation, and
  // recursively runs on a while instruction's condition and body until a fixed
  // point is reached.
  // NOTE(review): `visited_computations` presumably accumulates the
  // computations processed so far, and the returned bool presumably reports
  // whether anything was changed (driving the fixed-point iteration) — confirm
  // against the implementation.
  bool ResolveInconsistencyOfAliasingBuffersHelper(
      HloComputation* computation,
      absl::flat_hash_set<const HloComputation*>* visited_computations);

  // Makes the parameters of called computations match how they are called by
  // the given HLO.
  void AdjustCalledComputationParameters(HloInstruction* hlo);

  // Makes the root instructions of called computations match how they are used
  // by the given HLO.
  void AdjustCalledComputationRoot(HloInstruction* hlo);

  // ***************************
  // Functions called after changes in changes_to_bf16_ are applied.

  // Resolves inconsistencies introduced by this pass for fusions with
  // tuple-type output.
  Status ResolveInconsistentFusions(HloModule* module);

  // Converts the literals in kConstant HLOs whose types have been changed to
  // BF16 by this pass.
  Status ResolveConvertedConstants(HloModule* module);

  // Skips no-op conversions (same source and target shapes) that can be
  // produced by this pass, i.e., replaces them in their uses with their
  // operands.
  Status SkipNoopConversions(HloModule* module);

  // ***************************
  // Functions called and state used by two or more passes.

  // Returns whether all uses of the given HloInstruction's output at the given
  // shape index can consume BF16 input.
  bool AllUsersConsumeBF16(const HloInstruction& hlo,
                           const ShapeIndex& index) const;

  // The output element type of the HLO at the given shape index after changes
  // in changes_to_bf16_ are applied.
  PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
                                      const ShapeIndex& index) const;

  // The element type of the HLO value after changes in changes_to_bf16_ are
  // applied.
  PrimitiveType ValueTypeAfterChange(const HloValue* value) const;

  // If target_type == BF16, adds the HLO at the given index to
  // changes_to_bf16_. Otherwise, target_type must be F32, and this function
  // removes the HLO at the given index from changes_to_bf16_ if it was added
  // earlier.
  void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
                                      const ShapeIndex& index,
                                      PrimitiveType target_type);

  // The set of F32 HLO values that must be kept in F32.
  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;

  // Mapping from each HloComputation to the number of callers to it in the
  // module. Populated at the beginning of this pass.
  absl::flat_hash_map<const HloComputation*, int64> caller_counts_;

  // Potential F32-to-BF16 changes are first recorded in changes_to_bf16_.
  // They remain subject to further adjustment there, and are applied to the
  // HLOs only at the end. This avoids setting changed_ to true in cases where
  // all changes end up being reverted during adjustment.
  //
  // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
  // its output as a map from pointers to the (in-place) subshapes to their
  // shape indices.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
      changes_to_bf16_;

  // Whether the last processed HLO module has been changed by this pass.
  bool changed_ = false;

  // Backend-specific BF16 rules; presumably the pointer given to the
  // constructor. Not owned: the defaulted destructor does not delete it, so the
  // pointee must stay valid for as long as this pass is used.
  const BFloat16Support* bfloat16_support_;
  // Dataflow analysis owned by this pass (unique_ptr). NOTE(review):
  // presumably (re)computed in Run() for the module being processed, and used to
  // obtain the HloValues referenced above — confirm.
  std::unique_ptr<HloDataflowAnalysis> dataflow_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_