path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
author    Justin Lebar <jlebar@google.com>    2018-07-20 14:33:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-20 14:37:45 -0700
commit    67869d1266f17b4502391a13eac9180bda5bce0b (patch)
tree      02b001fa6a470de68a831d6d5dcb97b6633d77ef /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent    2c1f2847c2afbc6f23ef7040d49c71ffaa8b669c (diff)
[XLA] s/ir_builder/b/
Brevity.

PiperOrigin-RevId: 205454869
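
For illustration only (not part of the commit): a minimal sketch of what the rename looks like at a call site, assuming an emitter class that owns an llvm::IRBuilder<> member. The class and method names below are hypothetical; only the member rename ir_builder_ -> b_ mirrors the change in the diff.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"

namespace example {
// Hypothetical emitter owning the builder member that the commit renames.
class MiniEmitter {
 public:
  explicit MiniEmitter(llvm::LLVMContext& context) : b_(context) {}

  // Before the rename this call site read: ir_builder_.getInt32(1).
  llvm::Constant* EmitOne() { return b_.getInt32(1); }

 private:
  llvm::IRBuilder<> b_;  // formerly: llvm::IRBuilder<> ir_builder_;
};
}  // namespace example
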
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  560
1 files changed, 259 insertions, 301 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1caf10a6c1..7100c9a08a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -213,7 +213,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
llvm::LLVMContext& context = module->getContext();
llvm::FunctionType* kernel_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(context),
- std::vector<llvm::Type*>(args.size(), ir_builder_.getInt8PtrTy()),
+ std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
/*isVarArg=*/false);
llvm::Function* kernel =
llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
@@ -249,7 +249,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
nvvm_annotations_node->addOperand(llvm::MDNode::get(
context, {llvm::ConstantAsMetadata::get(kernel),
llvm::MDString::get(context, "kernel"),
- llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))}));
+ llvm::ConstantAsMetadata::get(b_.getInt32(1))}));
// Update the insert point to the entry basic block.
llvm::BasicBlock* entry_bb =
@@ -257,7 +257,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
// Emit a "return void" at entry_bb's end, and set the insert point before
// that return instruction.
- ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
+ b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
return kernel;
}
@@ -295,7 +295,7 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
// range of i32.
// Otherwise, the return type is i64.
llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
- llvm::IRBuilder<>* ir_builder) {
+ llvm::IRBuilder<>* b) {
// Find the unnested hlo instruction for which the kernel is generated.
const HloInstruction* unnested_hlo = hlo;
const HloComputation* computation = hlo->parent();
@@ -316,7 +316,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
return in_range;
};
- llvm::Type* i64_ty = ir_builder->getInt64Ty();
+ llvm::Type* i64_ty = b->getInt64Ty();
// Check launch dimension
if (!IsInt32(launch_size)) {
return i64_ty;
@@ -345,7 +345,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
}
}
- return ir_builder->getInt32Ty();
+ return b->getInt32Ty();
}
} // namespace
@@ -600,8 +600,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(
- hlo_module_config_, ir_emitter_context_->llvm_module(),
- &ir_builder_, GetNestedComputer());
+ hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
+ GetNestedComputer());
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
@@ -674,7 +674,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
- &ir_builder_, GetNestedComputer());
+ &b_, GetNestedComputer());
// Shape of the dynamic-update-slice's "update" operand.
Shape update_shape = root->operand(1)->shape();
@@ -692,7 +692,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
fusion, operand_arrays, output_array, &elemental_emitter,
- launch_dimensions, &ir_builder_);
+ launch_dimensions, &b_);
}
if (ImplementedAsGemm(*fusion)) {
@@ -740,11 +740,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
const HloInstruction* output = reduce->parent()->FusionInstruction();
llvm::Value* extra_output_address =
GetIrArray(*output, *output, extra_output_gens[i].second)
- .EmitArrayElementAddress(index, &ir_builder_,
+ .EmitArrayElementAddress(index, &b_,
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- ir_builder_.CreateStore(extra_output_ir_value, extra_output_address);
+ b_.CreateStore(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
@@ -774,8 +774,8 @@ Status IrEmitterUnnested::EmitReductionToScalar(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
- llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce, launch_dimensions.launch_bound(), &ir_builder_);
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
@@ -825,52 +825,51 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ llvm::Value* partial_reduction_result_address =
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- index_typed_constant(0),
- index_typed_constant(kTileSize),
- index_typed_constant(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop(
+ "element_id_in_tile", index_typed_constant(0),
+ index_typed_constant(kTileSize), index_typed_constant(1), &b_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
+ &b_);
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_constant(num_elems)),
- "x_in_bounds", &ir_builder_);
+ b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
+ &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
IrArray::Index input_index(
- /*linear=*/x, input_shape, &ir_builder_);
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
+ /*linear=*/x, input_shape, &b_);
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -881,52 +880,48 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = ir_builder_.CreateNSWAdd(
+ llvm::Value* x_end = b_.CreateNSWAdd(
index_typed_constant(kTileSize),
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
- llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
- ir_builder_.getInt1(all_threads_in_bounds));
+ llvm::Value* tile_in_bounds =
+ b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
- &ir_builder_);
+ llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
// After the if-then-else statement on tile_in_bounds, emit calls to
// shfl_down that accumulate the partial reduction results of all threads
// from the warp.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_);
int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? ir_builder_.getIntNTy(bit_width)
+ ? b_.getIntNTy(bit_width)
: element_ir_type;
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
- element_ir_type, nullptr, "result_from_other_lane");
+ llvm::Value* result_from_other_lane =
+ b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
+ llvm::Value* partial_reduction_result = b_.CreateLoad(
+ b_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
"partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- ir_builder_.CreateStore(
+ b_.CreateStore(
EmitFullWarpShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance),
- &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ b_.getInt32(shuffle_distance), &b_),
+ b_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -940,24 +935,23 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit an atomic operation that accumulates the partial reduction result of
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
- llvm::Value* lane_id = ir_builder_.CreateURem(
- x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
+ llvm::Value* lane_id =
+ b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
- "lane_id_is_zero", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
- &ir_builder_);
+ b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
IrArray::Index(
- /*linear=*/ir_builder_.getInt64(0),
+ /*linear=*/b_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
+ &b_),
+ &b_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
}
@@ -971,7 +965,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
@@ -1015,7 +1009,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
tiled_input_shape, ir_emitter_context_->device_description());
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
- llvm::Type* index_ty = ir_builder_.getInt64Ty();
+ llvm::Type* index_ty = b_.getInt64Ty();
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
@@ -1065,14 +1059,12 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
- ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value,
- partial_reduction_result_address);
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1083,50 +1075,47 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- index_typed_constant(0),
- index_typed_constant(kTileHeight),
- index_typed_constant(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop(
+ "element_id_in_tile", index_typed_constant(0),
+ index_typed_constant(kTileHeight), index_typed_constant(1), &b_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- llvm::Value* y = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(y_in_tiles,
- index_typed_constant(kTileHeight)),
+ &b_);
+ llvm::Value* y = b_.CreateNSWAdd(
+ b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
tile_element_loop->GetIndVarValue());
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
- "y_in_bounds", &ir_builder_);
+ b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
+ &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles,
- index_typed_constant(kTileWidth)),
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_constant(width)),
- "x_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1143,18 +1132,17 @@ Status IrEmitterUnnested::EmitColumnReduction(
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
- &ir_builder_);
+ &b_);
const IrArray::Index input_index =
input_matrix_index
.SourceIndexOfReshape(input_matrix_shape,
- normalized_input_shape, &ir_builder_)
+ normalized_input_shape, &b_)
.SourceIndexOfTranspose(normalized_input_shape, input_shape,
- transpose_dimension_mapping,
- &ir_builder_);
+ transpose_dimension_mapping, &b_);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
@@ -1169,64 +1157,55 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
- llvm::Value* y_end = ir_builder_.CreateNSWAdd(
+ llvm::Value* y_end = b_.CreateNSWAdd(
index_typed_constant(kTileHeight),
- ir_builder_.CreateNSWMul(y_in_tiles,
- index_typed_constant(kTileHeight)));
+ b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
- llvm::Value* x_end = ir_builder_.CreateNSWAdd(
+ llvm::Value* x_end = b_.CreateNSWAdd(
index_typed_constant(kTileWidth),
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
- llvm::Value* tile_in_y_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
- ir_builder_.getInt1(height % kTileHeight == 0));
- llvm::Value* tile_in_x_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, index_typed_constant(width)),
- ir_builder_.getInt1(width % kTileWidth == 0));
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* tile_in_y_bounds =
+ b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
+ llvm::Value* tile_in_x_bounds =
+ b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
- llvm_ir::LlvmIfData if_tile_in_y_bounds_data = llvm_ir::EmitIfThenElse(
- tile_in_y_bounds, "tile_in_y_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block,
- &ir_builder_);
+ llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_);
// The tile is in x bounds if "width" is a multiple of kTileWidth or
// x_end <= width.
- llvm_ir::LlvmIfData if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
- tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
- &ir_builder_);
+ llvm_ir::LlvmIfData if_tile_in_x_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
/*tile_in_x_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
/*tile_in_x_bounds=*/false));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block,
- &ir_builder_);
- if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
- tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_);
+ if_tile_in_x_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
/*tile_in_x_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
/*tile_in_x_bounds=*/false));
// After the nested if-then-else statement on tile_in_y_bounds and
// tile_in_x_bounds, emit atomic operations to accumulate the partial
// reduction result to the output element.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_);
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles,
- index_typed_constant(kTileWidth)),
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
@@ -1235,8 +1214,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
x,
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
+ &b_),
+ &b_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address,
partial_reduction_result_addresses[i * kTileWidth + x_offset]));
@@ -1252,7 +1231,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
@@ -1402,8 +1381,8 @@ Status IrEmitterUnnested::EmitRowReduction(
{depth / z_tile_size, height, width_in_tiles}, {2, 1, 0});
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
- llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce, launch_dimensions.launch_bound(), &ir_builder_);
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
@@ -1415,12 +1394,12 @@ Status IrEmitterUnnested::EmitRowReduction(
input_shape.element_type(), ir_emitter_context_->llvm_module());
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ llvm::Value* partial_reduction_result_address =
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1429,25 +1408,25 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
- llvm::Value* warp_id = ir_builder_.CreateUDiv(
- x_tile, index_typed_constant(kWarpSize), "warp_id");
- llvm::Value* lane_id = ir_builder_.CreateURem(
- x_tile, index_typed_constant(kWarpSize), "lane_id");
+ llvm::Value* warp_id =
+ b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
+ llvm::Value* lane_id =
+ b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = ir_builder_.CreateNSWAdd(
- lane_id, ir_builder_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- ir_builder_.CreateNSWAdd(
- index_typed_constant(x_tile_size - 1),
- ir_builder_.CreateNSWMul(
- warp_id, index_typed_constant(x_tile_size)))));
+ llvm::Value* last_x = b_.CreateNSWAdd(
+ lane_id,
+ b_.CreateNSWMul(
+ index_typed_constant(kWarpSize),
+ b_.CreateNSWAdd(
+ index_typed_constant(x_tile_size - 1),
+ b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
- &ir_builder_,
+ &b_,
/*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll,
/*prevent_vectorization=*/false);
@@ -1456,9 +1435,9 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = ir_builder_.CreateNSWAdd(
- z_indvar, ir_builder_.CreateNSWMul(
- index_typed_constant(z_tile_size), z_tile));
+ llvm::Value* z = b_.CreateNSWAdd(
+ z_indvar,
+ b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
/*start=*/index_typed_constant(0),
@@ -1466,12 +1445,12 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
+ llvm::Value* x = b_.CreateNSWAdd(
lane_id,
- ir_builder_.CreateNSWMul(
+ b_.CreateNSWMul(
index_typed_constant(kWarpSize),
- ir_builder_.CreateNSWAdd(
- x_indvar, ir_builder_.CreateNSWMul(
+ b_.CreateNSWAdd(
+ x_indvar, b_.CreateNSWMul(
warp_id, llvm::ConstantInt::get(
index_ty, x_tile_size)))));
@@ -1479,18 +1458,17 @@ Status IrEmitterUnnested::EmitRowReduction(
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
- llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT(
- x, index_typed_constant(width)),
- "x_in_bounds", &ir_builder_);
- // Points ir_builder_ to the then-block.
+ llvm_ir::EmitIfThenElse(
+ b_.CreateICmpULT(x, index_typed_constant(width)),
+ "x_in_bounds", &b_);
+ // Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
- &ir_builder_);
+ &b_);
}
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address =
- ir_builder_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1509,20 +1487,19 @@ Status IrEmitterUnnested::EmitRowReduction(
ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {depth, height, width});
const IrArray::Index input_3d_tensor_index(
- {z, y, x}, input_3d_tensor_shape, &ir_builder_);
+ {z, y, x}, input_3d_tensor_shape, &b_);
const IrArray::Index input_index =
input_3d_tensor_index
.SourceIndexOfReshape(input_3d_tensor_shape,
- normalized_input_shape,
- &ir_builder_)
+ normalized_input_shape, &b_)
.SourceIndexOfTranspose(
normalized_input_shape, input_shape,
- transpose_dimension_mapping, &ir_builder_);
+ transpose_dimension_mapping, &b_);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1541,9 +1518,9 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, emit_z_tile_element_loop);
};
- llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- ir_builder_.CreateICmpULT(last_x, index_typed_constant(width)));
+ llvm::Value* tile_in_bounds =
+ b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ b_.CreateICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1566,26 +1543,25 @@ Status IrEmitterUnnested::EmitRowReduction(
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? ir_builder_.getIntNTy(bit_width)
+ ? b_.getIntNTy(bit_width)
: element_ir_type;
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
- element_ir_type, nullptr, "result_from_other_lane");
+ llvm::Value* result_from_other_lane =
+ b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
+ llvm::Value* partial_reduction_result = b_.CreateLoad(
+ b_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
"partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- ir_builder_.CreateStore(
+ b_.CreateStore(
EmitFullWarpShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance),
- &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ b_.getInt32(shuffle_distance), &b_),
+ b_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1600,10 +1576,9 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
- "lane_id_is_zero", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
- &ir_builder_);
+ b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
@@ -1611,8 +1586,8 @@ Status IrEmitterUnnested::EmitRowReduction(
IrArray::Index(y,
ShapeUtil::GetSubshape(
output->shape(), reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
+ &b_),
+ &b_, "output_element_address");
// We don't need to emit atomic operations if there is only one tile of
// results. 'depth' is the z dimension, 'width' is the x dimension.
if (z_tile_size >= depth && x_tile_size >= width) {
@@ -1636,7 +1611,7 @@ Status IrEmitterUnnested::EmitRowReduction(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
@@ -1762,12 +1737,11 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
return EmitReductionToVector(
reduce, input->shape(), {[&](const IrArray::Index& index) {
- return GetIrArray(*input, *reduce)
- .EmitReadArrayElement(index, &ir_builder_);
+ return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_);
}},
{[&](const IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
- .EmitReadArrayElement(index, &ir_builder_);
+ .EmitReadArrayElement(index, &b_);
}},
dimensions_to_reduce, {reducer}, {{}}, {});
}
@@ -1842,7 +1816,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
source->shape(), ir_emitter_context_->device_description());
llvm::Type* index_type = GetIndexTypeForKernel(
- select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_);
+ select_and_scatter, launch_dimensions.launch_bound(), &b_);
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_type, c);
};
@@ -1873,19 +1847,18 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type,
ir_emitter_context_->llvm_module()),
- "selected_value_address", &ir_builder_);
+ "selected_value_address", &b_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
index_type, index_typed_constant(rank), "selected_index_address",
- &ir_builder_);
+ &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
- ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.getInt1(false),
- initialized_flag_address);
+ b_.getInt1Ty(), "initialized_flag_address", &b_);
+ b_.CreateStore(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
- llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"),
- &ir_builder_, index_type);
+ llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
+ index_type);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
@@ -1894,84 +1867,79 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
+ &b_);
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
IrArray::Index operand_index(index_type, source_index.size());
- llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ llvm::Value* strided_index = b_.CreateNSWMul(
source_index[i], index_typed_constant(window.dimensions(i).stride()));
- operand_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ operand_index[i] = b_.CreateNSWSub(
+ b_.CreateNSWAdd(strided_index, window_index[i]),
index_typed_constant(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ llvm::Value* index_condition = b_.CreateICmpULT(
operand_index[i],
index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
// Only need to do something if the operand index is within the bounds.
// First check if the initialized_flag is set.
llvm_ir::LlvmIfData if_in_bounds =
- llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateLoad(initialized_flag_address), "initialized",
- &ir_builder_);
+ b_.CreateLoad(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
- llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- ir_builder_.CreateInBoundsGEP(selected_index_address,
- {ir_builder_.getInt32(i)});
- ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ b_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
- operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
- ir_builder_.CreateStore(operand_data, selected_value_address);
+ operand_array.EmitReadArrayElement(operand_index, &b_);
+ b_.CreateStore(operand_data, selected_value_address);
save_operand_index(operand_index);
- ir_builder_.CreateStore(ir_builder_.getInt1(true),
- initialized_flag_address);
+ b_.CreateStore(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to
// potentially update the selected value and index with the currently
// visiting operand.
- llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
- operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
+ operand_array.EmitArrayElementAddress(operand_index, &b_);
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(PRED,
ir_emitter_context_->llvm_module()),
- "select_return_buffer", &ir_builder_);
+ "select_return_buffer", &b_);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer);
+ llvm::Value* result = b_.CreateLoad(select_return_buffer);
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = ir_builder_.CreateICmpNE(
+ llvm::Value* cond = b_.CreateICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
- llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
- selected_value_address);
+ llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
+ b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -1979,20 +1947,19 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// location is computed by calling the `scatter` function with the source
// value and the current output value.
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
+ &b_);
IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
- selected_index_address, {ir_builder_.getInt32(i)});
- selected_index.push_back(
- ir_builder_.CreateLoad(selected_index_address_slot));
+ llvm::Value* selected_index_address_slot =
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
- .EmitArrayElementAddress(source_index, &ir_builder_);
+ .EmitArrayElementAddress(source_index, &b_);
llvm::Value* output_value_address =
GetIrArray(*select_and_scatter, *select_and_scatter)
- .EmitArrayElementAddress(selected_index, &ir_builder_);
+ .EmitArrayElementAddress(selected_index, &b_);
return EmitAtomicOperationForNestedComputation(
*select_and_scatter->scatter(), output_value_address,
source_value_address);
@@ -2007,7 +1974,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, source->shape(),
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(select_and_scatter), index_type);
}
@@ -2326,18 +2293,16 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
<< " is found in slice " << slice.ToString() << " at GTE index "
<< gte_index.ToString();
- llvm::Value* loc =
- ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {ir_builder_.getInt64(slice.offset())});
+ llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
// If gte_index is nonempty, we have to dereference `loc` to get to the
// value we're ultimately interested in.
llvm::Type* int8_double_pointer =
- llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0);
+ llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = ir_builder_.CreateBitCast(loc, int8_double_pointer);
- loc = ir_builder_.CreateLoad(
- ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)}));
+ loc = b_.CreateBitCast(loc, int8_double_pointer);
+ loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
@@ -2349,7 +2314,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
} else {
bindings_.SetTempBufferBase(
- llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy()));
+ llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
@@ -2596,10 +2561,9 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
TF_RETURN_IF_ERROR(ParallelLoopEmitter(
[=](const IrArray::Index& index) {
return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &ir_builder_);
+ .EmitReadArrayElement(index, &b_);
},
- GetIrArray(*hlo, *hlo, index), launch_dimensions,
- &ir_builder_)
+ GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
.EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
@@ -2783,10 +2747,10 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
- launch_dimensions, &ir_builder_, unroll_factor)
- .EmitLoop(IrName(&hlo),
- GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(),
- &ir_builder_));
+ launch_dimensions, &b_, unroll_factor)
+ .EmitLoop(
+ IrName(&hlo),
+ GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
}
// For multioutput fusion, we need to emit each operand and the root.
@@ -2796,18 +2760,17 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
- &ir_builder_, unroll_factor)
+ &b_, unroll_factor)
.EmitLoop(IrName(&hlo),
GetIndexTypeForKernel(
- &hlo, launch_dimensions.launch_bound(), &ir_builder_)));
+ &hlo, launch_dimensions.launch_bound(), &b_)));
std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
- ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
- module_);
+ b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
return Status::OK();
}
@@ -2858,14 +2821,14 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
reduced_output_dims));
- output_in_reduced_shape_arrays->push_back(output_arrays[i].CastToShape(
- (*output_reduced_shapes)[i], &ir_builder_));
+ output_in_reduced_shape_arrays->push_back(
+ output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_));
}
} else {
output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
hlo.shape().element_type(), reduced_output_dims));
- output_in_reduced_shape_arrays->push_back(output_arrays[0].CastToShape(
- (*output_reduced_shapes)[0], &ir_builder_));
+ output_in_reduced_shape_arrays->push_back(
+ output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_));
}
return num_outputs;
}
@@ -2889,8 +2852,8 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
param->shape().element_type(),
Permute({0, 2, 1}, reduced_output_dims)));
- param_in_reduced_shape_arrays->push_back(param_arrays[id].CastToShape(
- (*param_reduced_shapes)[id], &ir_builder_));
+ param_in_reduced_shape_arrays->push_back(
+ param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_));
}
return num_params;
}
@@ -3039,7 +3002,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
kTileSize);
const int kNVPTXSharedMemoryAddrSpace = 3;
auto* tile_base_ptr = new llvm::GlobalVariable(
- *ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
+ *b_.GetInsertBlock()->getParent()->getParent(), tile_type,
/*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
llvm::UndefValue::get(tile_type),
llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr,
@@ -3063,8 +3026,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
- llvm::Type* index_ty = GetIndexTypeForKernel(
- hlo, launch_dimensions.launch_bound(), &ir_builder_);
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -3092,23 +3055,23 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* x;
llvm::Value* y;
std::tie(y, x) = CalculateYXCoordinateWithinTile(
- &ir_builder_, index_typed_constant(kTileSize), kThreadsPerTile);
+ &b_, index_typed_constant(kTileSize), kThreadsPerTile);
// Calculate the index for the current output tile from block_id.
const IrArray::Index output_tile_index(
- GetBlockIdx(&ir_builder_, index_ty, num_tiles),
+ GetBlockIdx(&b_, index_ty, num_tiles),
ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
output_dims_in_tiles),
- &ir_builder_);
+ &b_);
// Output tile origin is the index for the first element of the current output
// tile.
const IrArray::Index output_tile_origin = [&] {
IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] = ir_builder_.CreateMul(output_tile_index[i],
- index_typed_constant(kTileSize),
- "tile_origin." + std::to_string(i));
+ index[i] =
+ b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
}
return index;
}();
@@ -3121,16 +3084,15 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
// Only last row or column may not have full size.
- output_tile_bounds[i] = ir_builder_.CreateSelect(
- ir_builder_.CreateICmpEQ(
- output_tile_index[i],
- index_typed_constant(output_dims_in_tiles[i] - 1)),
+ output_tile_bounds[i] = b_.CreateSelect(
+ b_.CreateICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
index_typed_constant(reduced_output_dims[i] -
(output_dims_in_tiles[i] - 1) * kTileSize),
index_typed_constant(kTileSize), "kTileSize");
}
- KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+ KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
// Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
auto emit_tiled_elemental_code_with_bounds_check =
@@ -3139,13 +3101,13 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
const std::function<void(const IrArray::Index&, llvm::Value*)>&
emit_elem_function) {
EmitTiledElementalCodeWithBoundsCheck(
- kTileSize, kNumRows, index, loop_name, &ksl, &ir_builder_, y, x,
- tile_width, tile_height, emit_elem_function);
+ kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width,
+ tile_height, emit_elem_function);
};
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
- index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+ index[dim] = b_.CreateAdd(index[dim], addend);
return index;
};
const IrArray::Index input_index =
@@ -3161,19 +3123,17 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* shmem_buffer = param_shmem_buffers[id];
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
// global variables, so LLVM can't infer much about it.
- ir_builder_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+ b_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &b_,
"input_element"),
- ir_builder_.CreateGEP(shmem_buffer,
- {index_typed_constant(0), y_loc, x}));
+ b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
}
});
// Wait for all threads to reach this point, lest we copy a value from tile to
// output before the other thread copies it from input to tile.
// This is `__syncthreads` in CUDA.
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
- &ir_builder_);
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
@@ -3187,27 +3147,26 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
// TODO(jlebar): Add AA metadata to this load.
- llvm::Instruction* load_from_shmem_buffer = ir_builder_.CreateLoad(
- ir_builder_.CreateGEP(param_shmem_buffers[0],
- {ir_builder_.getInt64(0), x, y_loc}),
+ llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
+ b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
"output_element");
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
- index, load_from_shmem_buffer, &ir_builder_);
+ index, load_from_shmem_buffer, &b_);
});
} else {
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
emit_tiled_elemental_code_with_bounds_check(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
- GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
- &ir_builder_, GetNestedComputer());
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
+ GetNestedComputer());
FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
tiled_param_info.set_y(y_loc);
fused_emitter.SetTiledParameterInfo(&tiled_param_info);
TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
index, output_reduced_shapes[0], output_arrays[0].GetShape(),
- &ir_builder_);
+ &b_);
const llvm_ir::ElementGenerator& output_generator =
fused_emitter.GetRootGenerator();
llvm::Value* output_value =
@@ -3218,12 +3177,11 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_in_reduced_shape_arrays.size());
for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, ir_builder_.CreateExtractValue(output_value, i),
- &ir_builder_);
+ index, b_.CreateExtractValue(output_value, i), &b_);
}
} else {
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
- index, output_value, &ir_builder_);
+ index, output_value, &b_);
}
});
}
@@ -3234,7 +3192,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
- llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &ir_builder_,
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
module_);
}