aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
blob: bd5db7205155dc6b15ddea069e172bbd8f419996 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"

namespace xla {
namespace gpu {

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  // Constructs an emitter for `hlo_computation`, given the module
  // configuration and the shared emitter context. The constructor is defined
  // out of line; this header does not show what it does with its arguments.
  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);

  // Not copyable: the emitter uniquely owns its thunk sequence.
  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  // Transfers ownership of thunk_sequence_ to the caller.
  //
  // This moves out of the owning pointer, leaving thunk_sequence_ null.
  // LastThunk() dereferences thunk_sequence_, so it must not be called after
  // this function has been called.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::move(thunk_sequence_);
  }

  // Overrides the base-class default action for `hlo`.
  Status DefaultAction(HloInstruction* hlo) override;

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter.
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleCrossReplicaSum(HloInstruction* crs) override;
  Status HandleAfterAll(HloInstruction* gen_token) override;

  // Emits the target element loop for `hlo`, with `body_emitter` as the
  // element generator. Per the note on `EmitTargetElementLoopInThunk`, the
  // loop is emitted in `LastThunk()`.
  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
  // `LastThunk()`.
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk);

  // Emits LLVM global variables corresponding to constant instructions.
  Status EmitConstantGlobals();

 private:
  // Builds the appropriate thunk for the instruction `hlo` and returns the
  // owning pointer to it. The caller needs to make sure `hlo` outlives the
  // lifetime of the returned Thunk object.
  std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);

  // Builds the prototype of the IR kernel for `inst` and adds it to the module.
  // This kernel takes as arguments pointers to the given buffer allocations.
  llvm::Function* BuildKernelPrototype(
      const HloInstruction& inst,
      absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // EmitColumnReduction and EmitRowReduction emit code for column and row
  // reduction of a matrix and/or 3D tensor. Row and column reductions have
  // different memory access patterns, so for performance their
  // implementations are significantly different.
  //
  // Emits code that reduces a matrix of shape [height x width] to a vector of
  // [width]. Other parameters have the same meaning as those of
  // `EmitReductionToVector`. Note that input shape might not be
  // [height x width], but can be bitcast to [height x width] with "height"
  // being the major dimension.
  Status EmitColumnReduction(
      int64 height, int64 width, HloInstruction* reduce,
      const Shape& input_shape,
      absl::Span<const llvm_ir::ElementGenerator> input_gens,
      absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
      absl::Span<HloComputation* const> reducers,
      absl::Span<const ShapeIndex> reduce_output_shapes,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
  // vector of shape [height]. Other parameters have the same meaning as those
  // of `EmitReductionToVector`. Note that input shape might not be
  // [depth x height x width], but can be bitcast to [depth x height x width]
  // with "depth" being the most major dimension.
  Status EmitRowReduction(
      int64 depth, int64 height, int64 width, HloInstruction* reduce,
      const Shape& input_shape,
      absl::Span<const llvm_ir::ElementGenerator> input_gens,
      absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
      absl::Span<HloComputation* const> reducers,
      absl::Span<const ShapeIndex> reduce_output_shapes,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Emits code that reduces a tensor of arbitrary rank to a scalar.
  Status EmitReductionToScalar(
      HloInstruction* reduce, const Shape& input_shape,
      absl::Span<const llvm_ir::ElementGenerator> input_gens,
      absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
      absl::Span<HloComputation* const> reducers,
      absl::Span<const ShapeIndex> reduce_output_shapes,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Figures out whether `reduce` is a row or column reduction, and which
  // dimensions to reduce, and calls either `EmitRowReduction` or
  // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
  // input array, which is the operand of the Reduce instruction if unfused or
  // of the Fusion instruction if fused. `input_gens` and `init_value_gens`
  // generate elements of the inputs and the initial values. Other parameters
  // mean the same as for `HandleReduce`.
  //
  // Multiple reduces can be emitted in the same loop, assuming they have the
  // same input and output shapes, and the same reduce dimensions.
  //
  // `extra_output_gens` can contain extra generators for intermediate outputs.
  // These must have the same shape as the reduce input as they are computed
  // when the reduce inputs are being read.
  //
  // Prerequisite: `IsReductionToVector(*reduce)`
  Status EmitReductionToVector(
      HloInstruction* reduce, const Shape& input_shape,
      absl::Span<const llvm_ir::ElementGenerator> input_gens,
      absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
      absl::Span<const int64> dimensions_to_reduce,
      absl::Span<HloComputation* const> reducers,
      absl::Span<const ShapeIndex> reduce_output_shapes,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
  // for the hlo instruction.
  bool CheckAndEmitHloWithTile021(HloInstruction* hlo);

  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
  // returns the launch dimensions for the kernel. This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
                                  absl::Span<const int64> reduced_output_dims,
                                  absl::Span<const int64> tiled_param_ids);

  // Generates the IrArray for each input of an hlo and returns a vector that
  // contains such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
      const HloInstruction& hlo);

  // For each output of the `hlo` instruction, constructs the reduced shape for
  // the output with the given `reduced_output_dims` and casts the original
  // output IrArray element in `output_arrays` to the reduced shape. Returns
  // the number of outputs.
  int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
      const HloInstruction& hlo,
      const std::vector<llvm_ir::IrArray>& output_arrays,
      absl::Span<const int64> reduced_output_dims,
      std::vector<Shape>* output_reduced_shapes,
      std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);

  // For each input of the `hlo` instruction, checks its value in
  // `param_buffers` to find out whether the input has a reduced shape. If the
  // input has a reduced shape, constructs the reduced shape for the input and
  // casts the original input IrArray in `param_arrays` to the reduced shape.
  // Returns the total number of inputs.
  int ConstructInputReducedShapeAndCastInputIrArrayToShape(
      const HloInstruction& hlo,
      const std::vector<llvm_ir::IrArray>& param_arrays,
      const std::vector<llvm::Value*>& param_buffers,
      absl::Span<const int64> reduced_output_dims,
      std::vector<Shape>* param_reduced_shapes,
      std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object. The kernel implementation will be unrolled if
  // `unroll_factor` is greater than one. `implements_whole_instruction`
  // specifies whether this KernelThunk implements the whole `inst`
  // HloInstruction. In some cases `inst` will be implemented by a sequence of
  // Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction,
      int unroll_factor = 1);

  // Returns a FftThunk that calls cuFFT to implement `inst`.
  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);

  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
  // to make sure `inst` outlives the lifetime of the returned Thunk object.
  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);

  // Returns a thunk that, given a reduce or select-and-scatter op, initializes
  // its memory to the appropriate initial value.
  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
      HloInstruction* hlo, const ShapeIndex& index = {});

  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);

  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
      const HloInstruction* inst);

  // Returns an InfeedThunk that performs a host-to-device memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);

  // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);

  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
  // 'body' sub-computations of while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  //
  // `loop_limit` is taken by value without a top-level `const`: the qualifier
  // has no effect in a declaration, and a definition that keeps it still
  // matches this signature.
  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
                                       int64 loop_limit);

  // Returns a ConditionalThunk that executes the thunk sequence for
  // 'true_computation' or 'false_computation' depending on the value of the
  // predicate in the given conditional instruction.
  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);

  // Overrides the base-class post-processing hook for `hlo`.
  Status Postprocess(HloInstruction* hlo) override;

  // Returns the last generated thunk as a non-owning pointer.
  //
  // Precondition: thunk_sequence_ is non-null (i.e. ConsumeThunkSequence() has
  // not been called). NOTE(review): presumably the sequence must also be
  // non-empty for back() to be valid -- confirm against ThunkSequence.
  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }

  // The thunk sequence this IrEmitter generates for the input computation.
  std::unique_ptr<ThunkSequence> thunk_sequence_;

  // The HloComputation that this IrEmitter emits code for.
  const HloComputation* hlo_computation_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_