about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 64
1 files changed, 58 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 819060061a..616d8a2206 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
namespace xla {
namespace gpu {
@@ -73,8 +74,11 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleInfeed(HloInstruction* xla_infeed) override;
+ Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleSort(HloInstruction* sort) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
@@ -115,7 +119,7 @@ class IrEmitterUnnested : public IrEmitter {
// Emits code that reduces a matrix of shape [height x width] to a vector of
// [width]. Other parameters have the same meaning as those of
// `EmitReductionToVector`. Note that input shape might not be
- // [height x width], but can be bitcast to [height x weight] with "height"
+ // [height x width], but can be bitcast to [height x width] with "height"
// being the major dimension.
Status EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce,
@@ -131,7 +135,7 @@ class IrEmitterUnnested : public IrEmitter {
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
// vector of shape [height]. Other parameters have the same meaning as those
// of `EmitReductionToVector`. Note that input shape might not be
- // [depth x height x width], but can be bitcast to [depth x height x weight]
+ // [depth x height x width], but can be bitcast to [depth x height x width]
// with "depth" being the most major dimension.
Status EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
@@ -182,12 +186,56 @@ class IrEmitterUnnested : public IrEmitter {
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
+ // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
+ // for the hlo instruction.
+ bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
+ // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
+ // returns the launch dimensions for the kernel. This is a helper to support
+ // the implementation of CheckAndEmitHloWithTile021.
+ LaunchDimensions EmitHlo021Tile(
+ HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
+ // Generates the IrArray for each output of hlo and returns the number of
+ // outputs.
+ int ConstructIrArrayForOutputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* output_arrays);
+ // Generates the IrArray for each input of hlo and returns the number of
+ // inputs.
+ int ConstructIrArrayForInputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* param_arrays);
+ // For each output of the `hlo` instruction, constructs the reduced shape for
+ // the output with the given `reduced_output_dims` and cast the original
+ // output IrArray element in `output_arrays` to the reduced shape. Returns
+ // the number of outputs.
+ int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
+ // For each input of the `hlo` instruction, checks its value in
+ // `param_buffers` to find out whether the input has a reduced shape. If the
+ // input has a reduced shape, constructs the reduced shape for the input and
+ // casts the original input IrArray in `param_arrays` to the reduced shape.
+ // Returns the total number of inputs.
+ int ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
+
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
// Thunk object. The kernel implementation will be unrolled if unroll_factor
- // is greater than one.
- std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst,
- int unroll_factor = 1);
+ // is greater than one. 'implements_whole_instruction' specifies whether this
+ // KernelThunk implements the whole 'inst' HloInstruction. In some cases
+ // 'inst' will be implemented by a sequence of Thunks.
+ std::unique_ptr<KernelThunk> BuildKernelThunk(
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor = 1);
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
@@ -208,10 +256,14 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst);
- // Returns an InfeedThunk that performs device-to-device memcpy to implement
+ // Returns an InfeedThunk that performs a host-to-device memcpy to implement
// `inst`.
std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+ // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
+ // `inst`.
+ std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);
+
// Returns a WhileThunk that invokes thunk sequences for 'condition' and
// 'body' sub-computations of while instruction 'hlo'.
std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);