path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
author    Sanjoy Das <sanjoy@google.com>    2018-08-27 18:50:25 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-27 18:55:02 -0700
commit    fa607e7e9224b4d88ead0a81fc65c7884d25950a (patch)
tree      3d0acc58934efb515a5b28d38b9397b7aa467205 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent    9422d3d57c62399e425f63a769dcfa6ebd163bdc (diff)
Use a mixin to reduce llvm::IRBuilder<> related boilerplate.
PiperOrigin-RevId: 210472260
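
The change replaces explicit b_.CreateFoo(...) calls with short helpers (Store, Load, Alloca, NSWAdd, ...) inherited from a mixin. Below is a minimal, hypothetical sketch of such a CRTP mixin, not the actual XLA header; the builder() accessor name and the exact set of helpers are assumptions made for illustration.

    // Hypothetical sketch of an IRBuilder forwarding mixin (assumed names).
    #include <utility>
    #include "llvm/IR/IRBuilder.h"

    template <typename Derived>
    class IrBuilderMixin {
     protected:
      // Each helper forwards its arguments to the corresponding
      // llvm::IRBuilder<>::CreateXxx call on the derived emitter's builder.
      template <typename... Args>
      llvm::Value* Load(Args&&... args) {
        return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
      }

      template <typename... Args>
      llvm::Value* Store(Args&&... args) {
        return mixin_builder()->CreateStore(std::forward<Args>(args)...);
      }

      template <typename... Args>
      llvm::Value* Alloca(Args&&... args) {
        return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
      }

      template <typename... Args>
      llvm::Value* NSWAdd(Args&&... args) {
        return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
      }

     private:
      // Assumption: the derived emitter exposes its llvm::IRBuilder<> via
      // a builder() accessor; CRTP lets the mixin reach it without virtuals.
      llvm::IRBuilder<>* mixin_builder() {
        return static_cast<Derived*>(this)->builder();
      }
    };

    // An emitter would then inherit the helpers, e.g.:
    // class IrEmitterUnnested : public IrBuilderMixin<IrEmitterUnnested> { ... };

With such a mixin in place, call sites in the hunks below shrink from b_.CreateStore(init_ir_value, partial_reduction_result_address) to Store(init_ir_value, partial_reduction_result_address), which is exactly the rewrite this diff applies throughout ir_emitter_unnested.cc.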
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  264
1 file changed, 124 insertions, 140 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 4d98955c58..c0c8ae181a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -729,7 +729,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- b_.CreateStore(extra_output_ir_value, extra_output_address);
+ Store(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
@@ -810,17 +810,17 @@ Status IrEmitterUnnested::EmitReductionToScalar(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
@@ -832,15 +832,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)),
+ tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
@@ -849,11 +848,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
IrArray::Index input_index(
/*linear=*/x, input_shape, &b_);
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -864,14 +863,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileSize),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileSize),
+ NSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
- b_.getInt1(all_threads_in_bounds));
+ Or(ICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
@@ -892,20 +891,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -920,10 +917,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id =
- b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
@@ -1043,12 +1039,12 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1059,8 +1055,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
@@ -1072,34 +1068,32 @@ Status IrEmitterUnnested::EmitColumnReduction(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* y = b_.CreateNSWAdd(
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* y =
+ NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
+ tile_element_loop->GetIndVarValue());
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
- &b_);
+ ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1126,7 +1120,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
@@ -1141,20 +1135,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
- llvm::Value* y_end = b_.CreateNSWAdd(
- index_typed_constant(kTileHeight),
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
+ llvm::Value* y_end =
+ NSWAdd(index_typed_constant(kTileHeight),
+ NSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileWidth),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileWidth),
+ NSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
llvm::Value* tile_in_y_bounds =
- b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
- b_.getInt1(height % kTileHeight == 0));
+ Or(ICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
llvm::Value* tile_in_x_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
- b_.getInt1(width % kTileWidth == 0));
+ Or(ICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
@@ -1188,9 +1182,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
@@ -1379,11 +1373,11 @@ Status IrEmitterUnnested::EmitRowReduction(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1392,22 +1386,20 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = ZExtOrTrunc(x_tile, index_ty);
llvm::Value* warp_id =
- b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
+ UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id =
- b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = b_.CreateNSWAdd(
+ llvm::Value* last_x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- index_typed_constant(x_tile_size - 1),
- b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(index_typed_constant(x_tile_size - 1),
+ NSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&b_,
@@ -1419,9 +1411,8 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = b_.CreateNSWAdd(
- z_indvar,
- b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
+ llvm::Value* z =
+ NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
/*start=*/index_typed_constant(0),
@@ -1429,22 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = b_.CreateNSWAdd(
+ llvm::Value* x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- x_indvar, b_.CreateNSWMul(
- warp_id, llvm::ConstantInt::get(
- index_ty, x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(x_indvar,
+ NSWMul(warp_id, llvm::ConstantInt::get(
+ index_ty, x_tile_size)))));
// Unless we know the x-tile is entirely in bounds, we have to
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)),
- "x_in_bounds", &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
// Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&b_);
@@ -1452,7 +1441,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1483,7 +1472,7 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1503,8 +1492,8 @@ Status IrEmitterUnnested::EmitRowReduction(
};
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- b_.CreateICmpULT(last_x, index_typed_constant(width)));
+ Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ ICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1532,20 +1521,18 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1560,8 +1547,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
@@ -1845,7 +1831,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
&b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
@@ -1866,15 +1852,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
+ llvm::Value* strided_index = NSWMul(
source_index[i], index_typed_constant(window.dimensions(i).stride()));
- operand_index[i] = b_.CreateNSWSub(
- b_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_constant(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
+ operand_index[i] =
+ NSWSub(NSWAdd(strided_index, window_index[i]),
+ index_typed_constant(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ICmpULT(
operand_index[i],
index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -1884,7 +1870,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -1892,16 +1878,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to
// potentially update the selected value and index with the currently
@@ -1917,11 +1903,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = b_.CreateLoad(select_return_buffer);
+ llvm::Value* result = Load(select_return_buffer);
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
@@ -1930,7 +1916,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -1942,8 +1928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
@@ -2367,8 +2353,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
*slice.allocation())));
CHECK_NE(loc, nullptr);
} else {
- loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {b_.getInt64(slice.offset())});
+ loc = InBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
}
// If gte_index is nonempty, we have to dereference `loc` to get to the
@@ -2376,8 +2362,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::Type* int8_double_pointer =
llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = b_.CreateBitCast(loc, int8_double_pointer);
- loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
+ loc = BitCast(loc, int8_double_pointer);
+ loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
@@ -3154,9 +3140,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
const IrArray::Index output_tile_origin = [&] {
IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] =
- b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
- "tile_origin." + std::to_string(i));
+ index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
}
return index;
}();
@@ -3169,12 +3154,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
// Only last row or column may not have full size.
- output_tile_bounds[i] = b_.CreateSelect(
- b_.CreateICmpEQ(output_tile_index[i],
- index_typed_constant(output_dims_in_tiles[i] - 1)),
- index_typed_constant(reduced_output_dims[i] -
- (output_dims_in_tiles[i] - 1) * kTileSize),
- index_typed_constant(kTileSize), "kTileSize");
+ output_tile_bounds[i] =
+ Select(ICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
}
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
@@ -3192,7 +3177,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
- index[dim] = b_.CreateAdd(index[dim], addend);
+ index[dim] = Add(index[dim], addend);
return index;
};
const IrArray::Index input_index =
@@ -3208,10 +3193,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* shmem_buffer = param_shmem_buffers[id];
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
// global variables, so LLVM can't infer much about it.
- b_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &b_,
- "input_element"),
- b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
+ Store(input_in_logical_shape.EmitReadArrayElement(index, &b_,
+ "input_element"),
+ GEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
}
});
@@ -3232,9 +3216,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
// TODO(jlebar): Add AA metadata to this load.
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
- b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
- "output_element");
+ llvm::Instruction* load_from_shmem_buffer =
+ Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
+ "output_element");
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
index, load_from_shmem_buffer, &b_);
});
@@ -3262,7 +3246,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_in_reduced_shape_arrays.size());
for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, b_.CreateExtractValue(output_value, i), &b_);
+ index, ExtractValue(output_value, i), &b_);
}
} else {
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(