aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-07-11 09:06:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 09:09:36 -0700
commitdb2bb82963be35862de94425b30dfce17594eb82 (patch)
treef1fbb61a2063ac59ac982add337d1d75838e33fa /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent50d121e2365a4edffe59741df19e0dcc96291576 (diff)
[XLA:GPU] Cleanups to fused 021 transpose implementation.
- Fix typos. - Clarify comments. - Reduce nesting in a few places. - Add asserts that this code is dealing with specifically a loop fusion. - Rename some functions. In particular, it's confusing to have a function with a generic name like EmitCodeWithBoundCheck that actually is specialized to a tiled implementation. - Remove statement expression (GCC language extension), replacing it with an IIFE. - Don't refer to shared-memory tile space as "buffer" without other qualifying words, since that's ambiguous with what XLA refers to as a "buffer". - Use llvm::cast instead of static_cast. - Comply with style guide naming rules for compile-time constants (kFoo). - Use c_accumulate instead of std::accumulate. - Put std::function parameter at the end of the param list. This lets us cleanly embed the lambda into the call because of how clang-format formats such calls. (I think this one is possibly the most helpful change in this patch, as it suddenly makes clear to me the way that we use two calls to emit_tiled_elemental_code_with_bounds_check to emit the code.) PiperOrigin-RevId: 204134102
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc538
1 files changed, 265 insertions, 273 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 59edba30e6..a97447860e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -80,6 +80,7 @@ namespace gpu {
namespace {
+using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::InlinedVector;
@@ -693,16 +694,18 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
fusion, operand_arrays, output_array, &elemental_emitter,
launch_dimensions, &ir_builder_);
}
+
if (ImplementedAsGemm(*fusion)) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
+ CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
+
if (CheckAndEmitHloWithTile021(fusion)) {
return Status::OK();
}
- CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
int unroll_factor = ComputeMaxUnrollFactor(fusion);
thunk_sequence_->emplace_back(BuildKernelThunk(
@@ -774,7 +777,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm::Type* index_ty = GetIndexTypeForKernel(
reduce, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -840,21 +843,22 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)),
+ ir_builder_.CreateICmpULT(x, index_typed_constant(num_elems)),
"x_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
@@ -880,12 +884,12 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
llvm::Value* x_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)),
+ ir_builder_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
ir_builder_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
@@ -936,9 +940,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id = ir_builder_.CreateURem(
- x_in_tiles, index_typed_const(kWarpSize), "lane_id");
+ x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
@@ -1006,7 +1010,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
llvm::Type* index_ty = ir_builder_.getInt64Ty();
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -1062,22 +1066,23 @@ Status IrEmitterUnnested::EmitColumnReduction(
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* y = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// y-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(y, index_typed_const(height)),
+ ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
"y_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
@@ -1127,10 +1132,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
// immediately beyond the tile.
llvm::Value* y_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(y_end, index_typed_const(height)),
+ ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
ir_builder_.getInt1(height % kTileSize == 0));
// The tile is entirely in bound if "height" is a multiple of kTileSize or
// y_end <= height.
@@ -1326,7 +1331,7 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Type* index_ty = GetIndexTypeForKernel(
reduce, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -1353,20 +1358,20 @@ Status IrEmitterUnnested::EmitRowReduction(
x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty);
- llvm::Value* warp_id =
- ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id");
- llvm::Value* lane_id =
- ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id");
+ llvm::Value* warp_id = ir_builder_.CreateUDiv(
+ x_tile, index_typed_constant(kWarpSize), "warp_id");
+ llvm::Value* lane_id = ir_builder_.CreateURem(
+ x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
llvm::Value* last_x = ir_builder_.CreateNSWAdd(
lane_id, ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
+ index_typed_constant(kWarpSize),
ir_builder_.CreateNSWAdd(
- index_typed_const(x_tile_size - 1),
+ index_typed_constant(x_tile_size - 1),
ir_builder_.CreateNSWMul(
- warp_id, index_typed_const(x_tile_size)))));
+ warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&ir_builder_,
@@ -1379,19 +1384,19 @@ Status IrEmitterUnnested::EmitRowReduction(
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
llvm::Value* z = ir_builder_.CreateNSWAdd(
- z_indvar,
- ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile));
+ z_indvar, ir_builder_.CreateNSWMul(
+ index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(x_tile_loop_bound),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(x_tile_loop_bound),
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
llvm::Value* x = ir_builder_.CreateNSWAdd(
lane_id,
ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
+ index_typed_constant(kWarpSize),
ir_builder_.CreateNSWAdd(
x_indvar, ir_builder_.CreateNSWMul(
warp_id, llvm::ConstantInt::get(
@@ -1401,9 +1406,9 @@ Status IrEmitterUnnested::EmitRowReduction(
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
- llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(width)),
- "x_in_bounds", &ir_builder_);
+ llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT(
+ x, index_typed_constant(width)),
+ "x_in_bounds", &ir_builder_);
// Points ir_builder_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&ir_builder_);
@@ -1458,14 +1463,14 @@ Status IrEmitterUnnested::EmitRowReduction(
};
return ksl.For("z_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(z_tile_size),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(z_tile_size),
/*step=*/1, emit_z_tile_element_loop);
};
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- ir_builder_.CreateICmpULT(last_x, index_typed_const(width)));
+ ir_builder_.CreateICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1519,7 +1524,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
@@ -1763,7 +1768,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
source->shape(), ir_emitter_context_->device_description());
llvm::Type* index_type = GetIndexTypeForKernel(
select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_type, c);
};
@@ -1797,7 +1802,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
"selected_value_address", &ir_builder_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- index_type, index_typed_const(rank), "selected_index_address",
+ index_type, index_typed_constant(rank), "selected_index_address",
&ir_builder_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
@@ -1824,13 +1829,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- source_index[i], index_typed_const(window.dimensions(i).stride()));
+ source_index[i], index_typed_constant(window.dimensions(i).stride()));
operand_index[i] = ir_builder_.CreateNSWSub(
ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ index_typed_constant(window.dimensions(i).padding_low()));
llvm::Value* index_condition = ir_builder_.CreateICmpULT(
operand_index[i],
- index_typed_const(ShapeUtil::GetDimension(operand->shape(), i)));
+ index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
in_bounds_condition =
ir_builder_.CreateAnd(in_bounds_condition, index_condition);
}
@@ -2682,7 +2687,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
&ir_builder_));
}
- // For multiple outputs fusion, we need to emit each operand and the root.
+ // For multioutput fusion, we need to emit each operand and the root.
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
@@ -2792,8 +2797,9 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
namespace {
-// Emits the intrinsic call and meta data for thread_id. Calculates the (y,x)
-// coordinate from thread_id and returns the (y, x) coordinate.
+// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the
+// thread lives within a square tile of size tile_size (so thread blocks are of
+// size tile_size * tile_size).
std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
llvm::IRBuilder<>* builder, llvm::Value* tile_size,
int64 threads_per_tile) {
@@ -2802,42 +2808,42 @@ std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
llvm_ir::AddRangeMetadata(0, threads_per_tile,
- static_cast<llvm::Instruction*>(thread_id));
+ llvm::cast<llvm::Instruction>(thread_id));
thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
/*isSigned=*/true, "thread.id.x");
auto x = builder->CreateURem(thread_id, tile_size);
auto y = builder->CreateUDiv(thread_id, tile_size);
-
return std::make_tuple(y, x);
}
-// Emits the intrinsic call and meta data for block_id. Returns the block_id.
-llvm::Value* BuildBlockId(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
- int64 num_tiles) {
+// Reads block_idx.x, casts it to type index_ty, and adds the assumption that
+// it's in the range [0, num_blocks].
+llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+ int64 num_blocks) {
llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, num_tiles,
- static_cast<llvm::Instruction*>(block_id));
+ llvm_ir::AddRangeMetadata(0, num_blocks,
+ llvm::cast<llvm::Instruction>(block_id));
return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
"block.id.x");
}
// Emits code to process up to (tile_size/num_rows) elements in a tile, given
-// `emit_ele_function` is the function to emit code to process one element, `y`
+// `emit_elem_function` is the function to emit code to process one element, `y`
// and `x` are the coordinates for the first element to process, and `index` is
-// the index for the origin of the tile. Emits bound check to ensure that each
+// the index for the origin of the tile. Emits bounds check to ensure that each
// processed element is within the boundary defined by `tile_width` and
// `tile_height`.
-void EmitCodeWithBoundCheck(
- int64 tile_size, int64 num_rows,
- std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
- emit_ele_function,
- const llvm_ir::IrArray::Index& index, const string& loop_name,
- KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
- llvm::Value* x, llvm::Value* tile_width, llvm::Value* tile_height) {
+void EmitTiledElementalCodeWithBoundsCheck(
+ int64 tile_size, int64 num_rows, const llvm_ir::IrArray::Index& index,
+ const string& loop_name, KernelSupportLibrary* ksl,
+ llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
llvm::Type* index_ty = tile_width->getType();
// Emits a constant value with index type.
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
// Adds `addend` to the given `dim` of `index`.
@@ -2846,48 +2852,45 @@ void EmitCodeWithBoundCheck(
index[dim] = builder->CreateAdd(index[dim], addend);
return index;
};
- auto emit_not_last_row_true = [&]() {
+
+ auto emit_full_tile = [&] {
for (int64 i = 0; i < tile_size; i += num_rows) {
- auto source_idx = offset_dim(index, index_typed_const(i), /*dim=*/1);
- auto y_loc = builder->CreateAdd(index_typed_const(i), y);
- emit_ele_function(source_idx, y_loc);
+ auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(index_typed_constant(i), y);
+ emit_elem_function(source_idx, y_loc);
}
};
- auto emit_not_last_row_false = [&]() {
- ksl->IfReturnVoid(
- "x_in_tile", builder->CreateICmpULT(x, tile_width), [&]() {
- // tile_height_upper_bound = ceil(tile_height / num_rows) *
- // num_rows
- auto tile_height_upper_bound = builder->CreateMul(
- builder->CreateUDiv(
- builder->CreateAdd(tile_height,
- index_typed_const(num_rows - 1)),
- index_typed_const(num_rows)),
- index_typed_const(num_rows));
-
- ksl->ForReturnVoid(
- loop_name,
- /*start=*/index_typed_const(0),
- /*end=*/tile_height_upper_bound,
- /*step=*/index_typed_const(num_rows), [&](llvm::Value* y_indvar) {
- auto y_loc = builder->CreateAdd(y_indvar, y);
- ksl->IfReturnVoid(
- "y_in_tile", builder->CreateICmpULT(y_loc, tile_height),
- [&]() {
- emit_ele_function(offset_dim(index, y_indvar, /*dim=*/1),
- y_loc);
- });
- });
- });
+
+ auto emit_last_row = [&] {
+ ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] {
+ // tile_height_upper_bound =
+ // ceil(tile_height / num_rows) * num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height,
+ index_typed_constant(num_rows - 1)),
+ index_typed_constant(num_rows)),
+ index_typed_constant(num_rows));
+ ksl->ForReturnVoid(
+ loop_name, /*start=*/index_typed_constant(0),
+ /*end=*/tile_height_upper_bound,
+ /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) {
+ auto y_loc = builder->CreateAdd(y_indvar, y);
+ ksl->IfReturnVoid(
+ "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] {
+ emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1),
+ y_loc);
+ });
+ });
+ });
};
ksl->IfReturnVoid(
- "not_last_row",
+ "full_tile",
builder->CreateAnd(
- builder->CreateICmpEQ(index_typed_const(tile_size), tile_width),
- builder->CreateICmpEQ(index_typed_const(tile_size), tile_height)),
- emit_not_last_row_true, emit_not_last_row_false);
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width),
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)),
+ emit_full_tile, emit_last_row);
}
-
} // namespace
// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
@@ -2895,91 +2898,81 @@ void EmitCodeWithBoundCheck(
// which have a shape that is a 0-2-1 transpose of the output tensors.
//
// For the purpose of tiling, the output tensors have a logical shape of three
-// compoents 0-2-1 while the relevant input parameters have a logical shape of
+// components 0-2-1 while the relevant input parameters have a logical shape of
// three components 0-1-2 in the order major to minor. The x- and y- dimensions
-// of the tensores are tiled in square tiles of edge length `tile_size`. Each
-// thread block of `tile_size` x `num_rows` threads transposes one tile: each
-// thread copies tile_size/num_rows elements from the input to a shared memory
+// of the tensors are tiled in square tiles of edge length `kTileSize`. Each
+// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each
+// thread copies kTileSize/kNumRows elements from the input to a shared memory
// tile, then the otherwise "regular hlo kernel" reads from the shared memory
// instead of the original input.
//
-// This is similar to the algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
//
-// `tile_size` should usually be same as warp size. We currently choose
-// 32 for `tile_size` and 4 for `num_rows`. The CUDA algorithm usesa 8 for
-// `num_rows`.
+// `kTileSize` should usually be same as warp size. We currently choose 32 for
+// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
//
// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
+// to launch fewer blocks so each transposes many tiles.
LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
// Parameters for the tiling algorithm.
- constexpr int64 tile_size = 32;
- constexpr int64 num_rows = 4;
- constexpr int64 threads_per_tile = tile_size * num_rows;
+ constexpr int64 kTileSize = 32;
+ constexpr int64 kNumRows = 4;
+ constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
- // Construct the IrArray for the outputs.
- std::vector<llvm_ir::IrArray> output_arrays;
+ // Construct IrArrays for the inputs and outputs.
+ std::vector<IrArray> output_arrays;
int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
-
- // Construct the IrArray for the inputs and initialize the array to store
- // the tile buffers for the tiled parameters.
- std::vector<llvm_ir::IrArray> param_arrays;
+ std::vector<IrArray> param_arrays;
int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
- std::vector<llvm::Value*> param_buffers;
- param_buffers.reserve(num_params);
- for (int i = 0; i < num_params; ++i) {
- param_buffers.push_back(nullptr);
- }
- // Allocate share memory buffers to store the tiled parameters.
+ // Allocate shared memory buffers to store the tiled inputs.
+ std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
for (int64 id : tiled_param_ids) {
const HloInstruction* param = hlo->operand(id);
- // Add 1 to the minor dimension to reduce share memory bank conflict.
+ // Add 1 to the minor dimension to reduce shared memory bank conflicts.
llvm::Type* tile_type = llvm::ArrayType::get(
llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
param->shape().element_type(), module_),
- tile_size + 1),
- tile_size);
+ kTileSize + 1),
+ kTileSize);
+ const int kNVPTXSharedMemoryAddrSpace = 3;
auto* tile_base_ptr = new llvm::GlobalVariable(
*ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
/*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
- llvm::UndefValue::get(tile_type), "tile", nullptr,
- llvm::GlobalValue::NotThreadLocal,
- /*AddressSpace=*/3 /* GPU shared memory */);
- param_buffers[id] = tile_base_ptr;
- VLOG(3) << "Add buffer for parameter " << id << " "
+ llvm::UndefValue::get(tile_type),
+ llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr,
+ llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
+ param_shmem_buffers[id] = tile_base_ptr;
+ VLOG(3) << "Added shmem buffer for parameter " << id << ": "
<< llvm_ir::DumpToString(*tile_base_ptr);
}
// The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
- // for the purpose of tiling. Calculate the logical output dimensions in tile
- // from the reduced output dimensions.
+ // for the purpose of tiling. Calculate the logical output dimensions in the
+ // tile from the reduced output dimensions.
std::vector<int64> output_dims_in_tiles = std::vector<int64>(
reduced_output_dims.begin(), reduced_output_dims.end());
CHECK_EQ(output_dims_in_tiles.size(), 3);
for (int i = 1; i < 3; ++i) {
output_dims_in_tiles[i] =
- CeilOfRatio<int64>(output_dims_in_tiles[i], tile_size);
+ CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
}
const int64 num_tiles =
- std::accumulate(output_dims_in_tiles.begin(), output_dims_in_tiles.end(),
- 1, std::multiplies<int64>());
- LaunchDimensions launch_dimensions(num_tiles, threads_per_tile);
+ c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
llvm::Type* index_ty = GetIndexTypeForKernel(
hlo, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
// Cast each output IrArray to its corresponding reduced shape and keep the
// reduced shape live during IR emission.
- std::vector<llvm_ir::IrArray> output_in_reduced_shape_arrays;
+ std::vector<IrArray> output_in_reduced_shape_arrays;
std::vector<Shape> output_reduced_shapes;
CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
*hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
@@ -2988,10 +2981,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// For each tiled parameter, cast its input IrArray to the corresponding
// reduced shape and keep the reduced shape live during IR emission.
- std::vector<llvm_ir::IrArray> param_in_reduced_shape_arrays;
+ std::vector<IrArray> param_in_reduced_shape_arrays;
std::vector<Shape> param_reduced_shapes;
CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
- *hlo, param_arrays, param_buffers, reduced_output_dims,
+ *hlo, param_arrays, param_shmem_buffers, reduced_output_dims,
&param_reduced_shapes, &param_in_reduced_shape_arrays),
num_params);
@@ -3000,83 +2993,82 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* x;
llvm::Value* y;
std::tie(y, x) = CalculateYXCoordinateWithinTile(
- &ir_builder_, index_typed_const(tile_size), threads_per_tile);
+ &ir_builder_, index_typed_constant(kTileSize), kThreadsPerTile);
// Calculate the index for the current output tile from block_id.
- const llvm_ir::IrArray::Index output_tile_index(
- BuildBlockId(&ir_builder_, index_ty, num_tiles),
+ const IrArray::Index output_tile_index(
+ GetBlockIdx(&ir_builder_, index_ty, num_tiles),
ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
output_dims_in_tiles),
&ir_builder_);
// Output tile origin is the index for the first element of the current output
// tile.
- const llvm_ir::IrArray::Index output_tile_origin = ({
- llvm_ir::IrArray::Index index = output_tile_index;
+ const IrArray::Index output_tile_origin = [&] {
+ IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] = ir_builder_.CreateMul(index[i], index_typed_const(tile_size),
+ index[i] = ir_builder_.CreateMul(output_tile_index[i],
+ index_typed_constant(kTileSize),
"tile_origin." + std::to_string(i));
}
- index;
- });
+ return index;
+ }();
// Calculate the input tile origin from the output tile origin.
- const llvm_ir::IrArray::Index input_tile_origin(
+ const IrArray::Index input_tile_origin(
Permute({0, 2, 1}, output_tile_origin.multidim()));
- // A utility to add `addend` to the given `dim` of `index`.
- auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
- int64 dim) {
- index[dim] = ir_builder_.CreateAdd(index[dim], addend);
- return index;
- };
-
- // Calculate the current output tile bounds in each of the logical dimension.
+ // Calculate the current output tile bounds in each of the logical dimensions.
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
// Only last row or column may not have full size.
output_tile_bounds[i] = ir_builder_.CreateSelect(
ir_builder_.CreateICmpEQ(
output_tile_index[i],
- index_typed_const(output_dims_in_tiles[i] - 1)),
- index_typed_const(reduced_output_dims[i] -
- (output_dims_in_tiles[i] - 1) * tile_size),
- index_typed_const(tile_size), "tile_size");
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
}
KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
- // A wrapper of EmitCodeWithBoundCheck, to pass in the parameters common to
- // both call sites to of the routine.
- auto emit_code_with_bound_check =
- [&](std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
- emit_ele_function,
- const llvm_ir::IrArray::Index& index, const string& loop_name,
- llvm::Value* tile_width, llvm::Value* tile_height) {
- EmitCodeWithBoundCheck(tile_size, num_rows, emit_ele_function, index,
- loop_name, &ksl, &ir_builder_, y, x, tile_width,
- tile_height);
+
+ // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
+ auto emit_tiled_elemental_code_with_bounds_check =
+ [&](const IrArray::Index& index, const string& loop_name,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const llvm_ir::IrArray::Index&,
+ llvm::Value*)>& emit_elem_function) {
+ EmitTiledElementalCodeWithBoundsCheck(
+ kTileSize, kNumRows, index, loop_name, &ksl, &ir_builder_, y, x,
+ tile_width, tile_height, emit_elem_function);
};
- auto copy_tiled_params_to_buffers = [&](const llvm_ir::IrArray::Index& index,
- llvm::Value* y_loc) {
- for (int64 id : tiled_param_ids) {
- llvm_ir::IrArray& input_in_logical_shape =
- param_in_reduced_shape_arrays[id];
- llvm::Value* buffer = param_buffers[id];
- ir_builder_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
- "input_element"),
- ir_builder_.CreateGEP(buffer, {index_typed_const(0), y_loc, x}));
- }
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+ return index;
};
-
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_index =
offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
- // Copy input parameter values to shared memory buffers.
- emit_code_with_bound_check(
- // tile[y, x] = input[index]
- copy_tiled_params_to_buffers, input_index, "input", output_tile_bounds[1],
- output_tile_bounds[2]);
+
+ // Copy input parameter values to shared memory buffers:
+ // tile[y, x] = input[index]
+ emit_tiled_elemental_code_with_bounds_check(
+ input_index, "input", output_tile_bounds[1], output_tile_bounds[2],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id];
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ // TODO(jlebar): Add AA metadata to this store. Tile buffers are
+ // global variables, so LLVM can't infer much about it.
+ ir_builder_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+ "input_element"),
+ ir_builder_.CreateGEP(shmem_buffer,
+ {index_typed_constant(0), y_loc, x}));
+ }
+ });
// Wait for all threads to reach this point, lest we copy a value from tile to
// output before the other thread copies it from input to tile.
@@ -3084,59 +3076,60 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
&ir_builder_);
- llvm_ir::TiledParameterInfo tiled_param_info(param_buffers, y, x);
- // A helper to emit an otherwise usual kernel for an hlo instruction except
- // that values for the tiled parameters are read from tile buffers.
- auto emit_kernel_using_buffers = [&](const llvm_ir::IrArray::Index& index,
- llvm::Value* y_loc) {
- if (hlo->opcode() == HloOpcode::kCopy) {
- llvm::Instruction* load_from_buffer = ir_builder_.CreateLoad(
- ir_builder_.CreateGEP(param_buffers[0],
- {ir_builder_.getInt64(0), x, y_loc}),
- "output_element");
- output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
- index, load_from_buffer, &ir_builder_);
- } else {
- CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
- GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
- &ir_builder_, GetNestedComputer());
- FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
- tiled_param_info.set_y(y_loc);
- fused_emitter.SetTiledParameterInfo(&tiled_param_info);
- TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
- llvm_ir::IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
- index, output_reduced_shapes[0], output_arrays[0].GetShape(),
- &ir_builder_);
- const llvm_ir::ElementGenerator& output_generator =
- fused_emitter.GetRootGenerator();
- llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
- if (hlo->IsMultiOutputFusion()) {
- CHECK(output_value->getType()->isStructTy())
- << "This BodyEmitter is for multi-output fusion, but target "
- "element "
- "generator does not produce values of struct type.";
- CHECK_EQ(output_value->getType()->getStructNumElements(),
- output_in_reduced_shape_arrays.size());
- for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
- output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, ir_builder_.CreateExtractValue(output_value, i),
- &ir_builder_);
- }
- } else {
- output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
- index, output_value, &ir_builder_);
- }
- }
- };
+ llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
- const llvm_ir::IrArray::Index output_index =
+ const IrArray::Index output_index =
offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
- emit_code_with_bound_check(
- // output[index] = ...
- emit_kernel_using_buffers, output_index, "output", output_tile_bounds[2],
- output_tile_bounds[1]);
- // For multiple output fusion, emit a tuple with all the individual outputs.
+ // Write to output[index] by emitting code like normal, except that values for
+ // the tiled parameters are read from the shmem buffers.
+ if (hlo->opcode() == HloOpcode::kCopy) {
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ // TODO(jlebar): Add AA metadata to this load.
+ llvm::Instruction* load_from_shmem_buffer = ir_builder_.CreateLoad(
+ ir_builder_.CreateGEP(param_shmem_buffers[0],
+ {ir_builder_.getInt64(0), x, y_loc}),
+ "output_element");
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, load_from_shmem_buffer, &ir_builder_);
+ });
+ } else {
+ CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+ tiled_param_info.set_y(y_loc);
+ fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+ TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+ IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+ index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+ &ir_builder_);
+ const llvm_ir::ElementGenerator& output_generator =
+ fused_emitter.GetRootGenerator();
+ llvm::Value* output_value =
+ output_generator(untiled_index).ValueOrDie();
+ if (hlo->IsMultiOutputFusion()) {
+ CHECK(output_value->getType()->isStructTy());
+ CHECK_EQ(output_value->getType()->getStructNumElements(),
+ output_in_reduced_shape_arrays.size());
+ for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+ output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+ index, ir_builder_.CreateExtractValue(output_value, i),
+ &ir_builder_);
+ }
+ } else {
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, output_value, &ir_builder_);
+ }
+ });
+ }
+
+ // For multioutput fusion, emit a tuple with all the individual outputs.
if (hlo->IsMultiOutputFusion()) {
std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
@@ -3152,53 +3145,52 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
HloOpcode opcode = hlo->opcode();
CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+ CHECK(opcode != HloOpcode::kFusion ||
+ hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
+ << "Only loop fusions are supported.";
+
const Shape& output_shape = hlo->IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo->shape(), {0})
: hlo->shape();
- // If the output_shape is reduced to 021 shape find all the parameters of the
+
+ // If the output_shape is reduced to 021 shape, find all the parameters of the
// hlo that are in the corresponding 012 shape.
std::vector<int64> params_012;
- bool found = false;
- std::vector<int64> reduced_dims_021;
- int64 index = 0;
- for (HloInstruction* operand : hlo->operands()) {
+ optional<std::vector<int64>> reduced_dims_021;
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ HloInstruction* operand = hlo->mutable_operand(operand_idx);
auto find_transpose_result =
llvm_ir::FindTranspose021(operand->shape(), output_shape);
- if (find_transpose_result.has_value()) {
- const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
- if (found) {
- if (!ContainersEqual(reduced_dims_021, curr_reduced_dims_021)) {
- // There are more than one possible transpose. Instead of picking one
- // transpose, we simply give up here.
- found = false;
- break;
- }
- } else {
- found = true;
- reduced_dims_021 = curr_reduced_dims_021;
- }
- if (found) {
- params_012.push_back(index);
- }
+ if (!find_transpose_result.has_value()) {
+ continue;
+ }
+ const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+ if (!reduced_dims_021.has_value()) {
+ reduced_dims_021 = curr_reduced_dims_021;
+ }
+ if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ // There is more than one possible transpose. Instead of picking one
+ // transpose, we simply give up here.
+ return false;
}
- ++index;
+ params_012.push_back(operand_idx);
}
- if (!found) {
+ if (!reduced_dims_021.has_value()) {
return false;
}
- if (reduced_dims_021[1] < kMinDimensionToTransposeTiled ||
- reduced_dims_021[2] < kMinDimensionToTransposeTiled) {
+ if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
+ (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
return false;
}
VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
thunk_sequence_->emplace_back(
- BuildKernelThunk(hlo,
- /*implements_whole_instruction=*/true));
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
const LaunchDimensions launch_dimensions =
- EmitHlo021Tile(hlo, reduced_dims_021, params_012);
+ EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
UpdateLaunchDimensions(launch_dimensions, LastThunk(),
ir_emitter_context_->llvm_module());