about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc 60
1 files changed, 29 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index cd833ec7bd..3838fee674 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -32,27 +32,27 @@ namespace gpu {
ParallelLoopEmitter::ParallelLoopEmitter(
BodyEmitter body_emitter, const Shape& shape,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
- : LoopEmitter(body_emitter, shape, ir_builder),
+ : LoopEmitter(body_emitter, shape, b),
launch_dimensions_(launch_dimensions),
unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
- : LoopEmitter(target_element_generator, target_arrays, ir_builder),
+ : LoopEmitter(target_element_generator, target_arrays, b),
launch_dimensions_(launch_dimensions),
unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
- : LoopEmitter(target_element_generator, target_array, ir_builder),
+ : LoopEmitter(target_element_generator, target_array, b),
launch_dimensions_(launch_dimensions),
unroll_factor_(unroll_factor) {}
@@ -74,29 +74,27 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
CHECK_NE(index_type, nullptr);
std::vector<llvm_ir::IrArray::Index> array_indices;
llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_);
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
static_cast<llvm::Instruction*>(block_id));
- block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id");
+ block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id");
// Per the PTX documentation:
// "It is guaranteed that [...] 0 <= %tid.x < %ntid.x"
//
// %ntid.x is currently specified as 1024.
llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_);
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(),
static_cast<llvm::Instruction*>(thread_id));
- thread_id =
- ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id");
-
- llvm::Value* linear_index_base = ir_builder_->CreateAdd(
- ir_builder_->CreateMul(
- block_id,
- llvm::ConstantInt::get(index_type,
- launch_dimensions_.threads_per_block()),
- "",
- /*HasNUW=*/true, /*HasNSW=*/true),
+ thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id");
+
+ llvm::Value* linear_index_base = b_->CreateAdd(
+ b_->CreateMul(block_id,
+ llvm::ConstantInt::get(
+ index_type, launch_dimensions_.threads_per_block()),
+ "",
+ /*HasNUW=*/true, /*HasNSW=*/true),
thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true);
// Add an @llvm.assume(linear_index < threads_per_block * num_blocks).
@@ -109,41 +107,41 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
// conditions in the same basic block as their operands.
llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::assume,
- {ir_builder_->CreateICmpULT(
+ {b_->CreateICmpULT(
linear_index_base,
llvm::ConstantInt::get(index_type,
launch_dimensions_.threads_per_block() *
launch_dimensions_.block_count()),
"linear_index_in_range")},
- {}, ir_builder_);
+ {}, b_);
if (unroll_factor_ > 1) {
- linear_index_base = ir_builder_->CreateMul(
+ linear_index_base = b_->CreateMul(
linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_),
"linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
}
- array_indices.emplace_back(linear_index_base, shape_, ir_builder_);
+ array_indices.emplace_back(linear_index_base, shape_, b_);
for (int i = 1; i < unroll_factor_; ++i) {
- llvm::Value* linear_index = ir_builder_->CreateAdd(
- linear_index_base, llvm::ConstantInt::get(index_type, i),
- "linear_index",
- /*HasNUW=*/true, /*HasNSW=*/true);
- array_indices.emplace_back(linear_index, shape_, ir_builder_);
+ llvm::Value* linear_index =
+ b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i),
+ "linear_index",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ array_indices.emplace_back(linear_index, shape_, b_);
}
auto if_in_bounds = llvm_ir::EmitIfThenElse(
- ir_builder_->CreateICmpULT(
+ b_->CreateICmpULT(
linear_index_base,
llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))),
- llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false);
+ llvm_ir::IrName(loop_name, "in_bounds"), b_, false);
// Set exit_bb_ to the exit block of the if structure.
exit_bb_ = if_in_bounds.after_block;
CHECK_NE(nullptr, exit_bb_);
// Set IR builder insertion point to the body of the if structure.
- llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_);
return array_indices;
}