diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc | 60 |
1 files changed, 29 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index cd833ec7bd..3838fee674 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -32,27 +32,27 @@ namespace gpu { ParallelLoopEmitter::ParallelLoopEmitter( BodyEmitter body_emitter, const Shape& shape, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(body_emitter, shape, ir_builder), + : LoopEmitter(body_emitter, shape, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_arrays, ir_builder), + : LoopEmitter(target_element_generator, target_arrays, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_array, ir_builder), + : LoopEmitter(target_element_generator, target_array, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} @@ -74,29 +74,27 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( CHECK_NE(index_type, nullptr); std::vector<llvm_ir::IrArray::Index> array_indices; llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast<llvm::Instruction*>(block_id)); - block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id"); + block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" // // %ntid.x is currently specified as 1024. llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast<llvm::Instruction*>(thread_id)); - thread_id = - ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); - - llvm::Value* linear_index_base = ir_builder_->CreateAdd( - ir_builder_->CreateMul( - block_id, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); + + llvm::Value* linear_index_base = b_->CreateAdd( + b_->CreateMul(block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.threads_per_block()), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). @@ -109,41 +107,41 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // conditions in the same basic block as their operands. llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::assume, - {ir_builder_->CreateICmpULT( + {b_->CreateICmpULT( linear_index_base, llvm::ConstantInt::get(index_type, launch_dimensions_.threads_per_block() * launch_dimensions_.block_count()), "linear_index_in_range")}, - {}, ir_builder_); + {}, b_); if (unroll_factor_ > 1) { - linear_index_base = ir_builder_->CreateMul( + linear_index_base = b_->CreateMul( linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } - array_indices.emplace_back(linear_index_base, shape_, ir_builder_); + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { - llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, llvm::ConstantInt::get(index_type, i), - "linear_index", - /*HasNUW=*/true, /*HasNSW=*/true); - array_indices.emplace_back(linear_index, shape_, ir_builder_); + llvm::Value* linear_index = + b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", + /*HasNUW=*/true, /*HasNSW=*/true); + array_indices.emplace_back(linear_index, shape_, b_); } auto if_in_bounds = llvm_ir::EmitIfThenElse( - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( linear_index_base, llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), - llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); + llvm_ir::IrName(loop_name, "in_bounds"), b_, false); // Set exit_bb_ to the exit block of the if structure. exit_bb_ = if_in_bounds.after_block; CHECK_NE(nullptr, exit_bb_); // Set IR builder insertion point to the body of the if structure. - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_); return array_indices; } |