Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 1715
1 file changed, 1036 insertions, 679 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bdb9e77da4..1f31a7f36b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -28,7 +28,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -58,10 +59,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -70,6 +72,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" @@ -79,6 +82,7 @@ namespace gpu { namespace { +using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; using tensorflow::gtl::InlinedVector; @@ -211,7 +215,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( llvm::LLVMContext& context = module->getContext(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector<llvm::Type*>(args.size(), ir_builder_.getInt8PtrTy()), + std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, @@ -228,7 +232,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); kernel->addParamAttr( arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - kCudaMallocAlignBytes)); + alloc->is_entry_computation_parameter() + ? 
kEntryParameterAlignBytes + : kXlaAllocatedBufferAlignBytes)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -247,7 +253,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( nvvm_annotations_node->addOperand(llvm::MDNode::get( context, {llvm::ConstantAsMetadata::get(kernel), llvm::MDString::get(context, "kernel"), - llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))})); + llvm::ConstantAsMetadata::get(b_.getInt32(1))})); // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = @@ -255,7 +261,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. - ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); + b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); return kernel; } @@ -293,7 +299,7 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // range of i32. // Otherwise, the return type is i64. llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { // Find the unnested hlo instructon for which the kernel is generated for. const HloInstruction* unnested_hlo = hlo; const HloComputation* computation = hlo->parent(); @@ -314,7 +320,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, return in_range; }; - llvm::Type* i64_ty = ir_builder->getInt64Ty(); + llvm::Type* i64_ty = b->getInt64Ty(); // Check launch dimension if (!IsInt32(launch_size)) { return i64_ty; @@ -343,7 +349,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, } } - return ir_builder->getInt32Ty(); + return b->getInt32Ty(); } } // namespace @@ -355,7 +361,8 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { unroll_factor = ComputeMaxUnrollFactor(hlo); } - thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); + thunk_sequence_->emplace_back(BuildKernelThunk( + hlo, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -369,7 +376,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(dot)); + thunk_sequence_->emplace_back( + BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); return IrEmitter::HandleDot(dot); } @@ -379,7 +387,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); + thunk_sequence_->emplace_back( + BuildKernelThunk(convolution, /*implements_whole_instruction=*/true)); return IrEmitter::HandleConvolution(convolution); } @@ -586,16 +595,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } } CHECK(first_reduce != nullptr); - thunks.push_back(BuildKernelThunk(fusion)); + thunks.push_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique<SequentialThunk>(std::move(thunks), fusion)); - std::vector<llvm_ir::IrArray> parameter_arrays; + std::vector<IrArray> parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, 
GetNestedComputer()); + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); @@ -660,21 +670,22 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. - thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector<llvm_ir::IrArray> operand_arrays; + thunk_sequence_->emplace_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/true)); + std::vector<IrArray> operand_arrays; for (HloInstruction* operand : fusion->operands()) { operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()); + &b_, GetNestedComputer()); // Shape of the dynamic-update-slice's "update" operand. Shape update_shape = root->operand(1)->shape(); // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); + IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -685,316 +696,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( fusion, operand_arrays, output_array, &elemental_emitter, - launch_dimensions, &ir_builder_); + launch_dimensions, &b_); } + if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int unroll_factor = ComputeMaxUnrollFactor(fusion); - - thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); - return IrEmitter::HandleFusion(fusion); -} + CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop); -namespace { - -// Returns the indices of the first elements of all consecutive subarrays of the -// given array. For example: -// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) { - std::vector<size_t> is = {0}; - for (size_t i = 1; i < xs.size(); ++i) { - if (1 != xs[i] - xs[i - 1]) { - is.push_back(i); - } - } - return is; -} - -// Merges the sequences of dimensions of the given shape which start at the -// given indices `segs`. -Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs, - const Shape& shape) { - std::vector<int64> dimensions; - for (size_t i = 1; i <= segs.size(); ++i) { - dimensions.push_back(std::accumulate( - shape.dimensions().begin() + segs[i - 1], - shape.dimensions().begin() + - (segs.size() == i ? shape.dimensions().size() : segs[i]), - 1, std::multiplies<int64>())); - } - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - dimensions); -} - -// Returns whether the given shapes and permutation are a 0-2-1 transpose, and -// if so, the normalized and rank-reduced shapes. The shapes must have the same -// dimensions, so this considers layout only. -// -// This function recognizes higher-rank transposes which are elementwise -// equivalent to a 0-2-1 transpose. 
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) { - CHECK(ShapeUtil::Compatible(a, b)); - std::vector<int64> perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + if (CheckAndEmitHloWithTile021(fusion)) { + return Status::OK(); } - auto segs = ConsecutiveSegments(perm); - Shape norm_a = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape norm_b = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); - if (3 == segs.size() && 0 == perm[0]) { - Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( - b.element_type(), - Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); - return std::make_tuple(true, reduced_a, reduced_b); - } - return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil()); -} - -// Returns whether the given shapes are potentially of a 0-2-1 transpose. -// As 0-2-1 is a self-inverse permutation, which shape is input or output is -// arbitrary. -bool AreShapesForTranspose021(const Shape& a, const Shape& b) { - return 3 == b.dimensions().size() && - ShapeUtil::Compatible( - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), - ShapeUtil::PermuteDimensions( - {0, 2, 1}, - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - b))); -} - -// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from -// major to minor. The x- and y- dimensions are tiled in square tiles of edge -// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads -// transposes one tile: each thread copies a row from the input to a shared -// memory tile, then copies a column from the shared memory tile to the output. -// -// `tile_size` should usually be same as warp size. -// -// Returns (number of tiles = number of thread blocks needed). -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles, and -// in any case, the number of blocks we can launch is limited. -// -// This is the same algorithm in CUDA: -// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189 -int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, - const int64 tile_size, const int64 num_rows, - llvm::IRBuilder<>* builder) { - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [builder](llvm_ir::IrArray::Index index, - llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); - - Shape input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input.GetShape()); - Shape output_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - output.GetShape()); - input = input.CastToShape(input_shape, builder); - output = output.CastToShape(output_shape, builder); - - llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(input.GetElementLlvmType(), tile_size), - // One extra here to avoid share memory bank conflict - tile_size + 1); - auto* tile = new llvm::GlobalVariable( - *builder->GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), "tile", nullptr, - llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/3 /* GPU shared memory */); - - // let x = threadIdx.x - llvm::Value* x = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_rows * tile_size, - static_cast<llvm::Instruction*>(x)); - x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true, - "thread.id.x"); - - // computing logical thread ids - // logical_x = x % tile_size - auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size)); - - // logical_y = x / tile_size - auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size)); - - // `emit_cp` emits equivalent to following pseudocode: - // if (tile_size == tile_width && tile_size == tile_height) { - // unroll for (i in range(0, tile_size, num_rows)) { - // emit_cp_element(index + {0, i, 0}, y + logical_y); - // } - // } else if (x < tile_width) { - // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows; - // for (i in range(0, tile_height_upperbound, num_rows)) { - // y_loc = i + logical_y; - // if (y_loc < tile_height) - // emit_cp_element(index + {0, i, 0}, y_loc); - // } - // } - // - // We use this to emit both the copy from input to tile and the copy from tile - // to output. - // - // `index` is the origin of the row or column in the input or output array. - // - // `emit_cp_element(index, y)` emits code to copy a single element between the - // tile and the input or output array, where `y` is the `y`-position in the - // tile, whether which is row or column is a function of whether we're copying - // from input or to output, and `index` is the index into the input or output - // array. 
- auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x, - logical_y]( - std::function<void(const llvm_ir::IrArray::Index&, - llvm::Value*)> - emit_cp_element, - llvm::Value* tile_width, llvm::Value* tile_height, - const llvm_ir::IrArray::Index& index, - const string& loop_name) { - llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse( - builder->CreateAnd( - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width), - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)), - "not_last_row", builder); - builder->SetInsertPoint(if_not_last_row.true_block->getTerminator()); - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y); - emit_cp_element(source_idx, y_loc); - } - builder->SetInsertPoint(if_not_last_row.false_block->getTerminator()); - llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder); - builder->SetInsertPoint(if_in_tile.true_block->getTerminator()); - - // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)), - builder->getInt64(num_rows)), - builder->getInt64(num_rows)); - - auto loop = llvm_ir::ForLoop::EmitForLoop( - loop_name, builder->getInt64(0), tile_height_upper_bound, - builder->getInt64(num_rows), builder); - llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder); - builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator()); - - auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y); - auto if_y_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder); - builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator()); - - emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1), - y_loc); - builder->SetInsertPoint(if_not_last_row.after_block->getTerminator()); - }; - - auto input_dims_in_tiles = input_shape.dimensions(); - // Unpermuted dimensions are untiled. - for (int i = 1; i < 3; ++i) { - input_dims_in_tiles[i] = - CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size); - } - int64 num_tiles = - std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1, - std::multiplies<int64>()); - const llvm_ir::IrArray::Index input_tile_index( - /*linear=*/builder->CreateIntCast( - llvm_ir::AddRangeMetadata( - 0, num_tiles, - static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, - builder))), - builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithDescendingLayout( - PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), - builder); - const llvm_ir::IrArray::Index input_tile_origin = ({ - llvm_ir::IrArray::Index index = input_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size), - "tile_origin." + std::to_string(i)); - } - index; - }); - const llvm_ir::IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y, - /*dim=*/1); - std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size()); - // Only last row or column may not have full size. 
- for (int i = 1; i < 3; ++i) { - tile_dims[i] = builder->CreateSelect( - builder->CreateICmpEQ(input_tile_index[i], - builder->getInt64(input_dims_in_tiles[i] - 1)), - builder->getInt64(input_shape.dimensions(i) - - (input_dims_in_tiles[i] - 1) * tile_size), - builder->getInt64(tile_size), "tile_size"); - } - - // Load data from input memory to shared memory tile. - emit_cp_tile( - // tile[y, x] = input_array[index] - [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - builder->CreateStore( - input.EmitReadArrayElement(index, builder, "input_element"), - builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x})); - }, - tile_dims[2], tile_dims[1], input_index, "input"); + int unroll_factor = ComputeMaxUnrollFactor(fusion); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder); - - const llvm_ir::IrArray::Index output_tile_index( - Permute({0, 2, 1}, input_tile_index.multidim())); - const llvm_ir::IrArray::Index output_tile_origin( - Permute({0, 2, 1}, input_tile_origin.multidim())); - const llvm_ir::IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2), - logical_y, /*dim=*/1); - - // Store data from shared memory tile to output memory. - emit_cp_tile( - // output_array[index] = tile[x, y] - [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - output.EmitWriteArrayElement( - index, - builder->CreateLoad( - builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}), - "output_element"), - builder); - }, - tile_dims[1], tile_dims[2], output_index, "output"); - - return num_tiles; + thunk_sequence_->emplace_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/true, unroll_factor)); + return IrEmitter::HandleFusion(fusion); } -} // namespace - Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), *copy)) { @@ -1006,25 +728,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } - bool is_transpose_021; - Shape reduced_input_shape, reduced_output_shape; - std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(copy->operand(0)->shape(), copy->shape()); - if (is_transpose_021 && - reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && - reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { - thunk_sequence_->emplace_back(BuildKernelThunk(copy)); - VLOG(3) << "Emitting tiled 0-2-1 transposition"; - constexpr int64 tile_size = 32; - constexpr int64 num_rows = 8; - int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*copy->operand(0), *copy) - .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy, *copy) - .CastToShape(reduced_output_shape, &ir_builder_), - tile_size, num_rows, &ir_builder_); - UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), - LastThunk(), ir_emitter_context_->llvm_module()); + if (CheckAndEmitHloWithTile021(copy)) { return Status::OK(); } @@ -1032,7 +736,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& 
index, + const HloInstruction* reduce, const IrArray::Index& index, tensorflow::gtl::ArraySlice< std::pair<llvm_ir::ElementGenerator, ShapeIndex>> extra_output_gens) { @@ -1040,11 +744,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = GetIrArray(*output, *output, extra_output_gens[i].second) - .EmitArrayElementAddress(index, &ir_builder_, + .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + b_.CreateStore(extra_output_ir_value, extra_output_address); } return Status::OK(); } @@ -1074,12 +778,10 @@ Status IrEmitterUnnested::EmitReductionToScalar( LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = GetIndexTypeForKernel( - reduce, - launch_dimensions.block_count() * launch_dimensions.threads_per_block(), - &ir_builder_); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -1121,59 +823,57 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // and threads_per_block is a multiple of warpSize. // reduce_kernel<<<num_blocks, threads_per_block>>>(); // - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector<llvm::Value*> partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr<llvm_ir::ForLoop> tile_element_loop = llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_const(0), - index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileSize), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. 
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)), + &b_); + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)), - "x_in_bounds", &ir_builder_); + b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm_ir::IrArray::Index input_index( - /*linear=*/x, input_shape, &ir_builder_); - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + IrArray::Index input_index( + /*linear=*/x, input_shape, &b_); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1184,49 +884,48 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = ir_builder_.CreateNSWAdd( - index_typed_const(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize))); + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileSize), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)), - ir_builder_.getInt1(all_threads_in_bounds)); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); // After the if-then-else statement on tile_in_bounds, emit calls to // shfl_down that accumulate the partial reduction results of all threads // from the warp. 
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1240,24 +939,23 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit an atomic operation that accumulates the partial reduction result of // lane 0 (which holds the partially accumulated result for its warp) to the // output element. 
- llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, index_typed_const(kWarpSize), "lane_id"); + llvm::Value* lane_id = + b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - /*linear=*/ir_builder_.getInt64(0), + IrArray::Index( + /*linear=*/b_.getInt64(0), ShapeUtil::GetSubshape(output->shape(), reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); + &b_), + &b_, "output_element_address"); TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( *reducers[i], output_address, partial_reduction_result_addresses[i])); } @@ -1271,7 +969,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1284,8 +982,8 @@ Status IrEmitterUnnested::EmitColumnReduction( tensorflow::gtl::ArraySlice< std::pair<llvm_ir::ElementGenerator, ShapeIndex>> extra_output_gens) { - // Divide the input matrix into tiles of size Kx1. For example, when the - // input matrix is 4x4 and K=2, the tiled matrix looks like + // Divide the input matrix into tiles of size KxL. For example, when the + // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like // // 0123 // 0123 @@ -1297,100 +995,131 @@ Status IrEmitterUnnested::EmitColumnReduction( // // We choose 128 as the tile size based on empirical evidence. It's big enough // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. - constexpr int64 kTileSize = 128; + // bandwidth. A tile width of 2 allows for high memory bandwidth utilization + // on 16b input data. + constexpr int64 kTileHeight = 128; + constexpr int64 kTileWidth = 2; - // If the height is not a multiple of the tile size, we pad the bottom of the + // If the height is not a multiple of kTileHeight, we pad the bottom of the // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); + const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); + // If width is not a multiple of kTileWidth the rightmost thread will process + // fewer input elements. + const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); + Shape tiled_input_shape = + ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), + {height_in_tiles, width_in_tiles}, {1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); // TODO(b/110211620): Convert to use i32 index_type when it is possible. 
- llvm::Type* index_ty = ir_builder_.getInt64Ty(); + llvm::Type* index_ty = b_.getInt64Ty(); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width; + // linear_index < height_in_tiles * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width; - // x = linear_index % width; + // y_in_tiles = linear_index / width_in_tiles; + // x_in_tiles = linear_index % width_in_tiles; // - // partial_result = init_value; - // if (height % kTileSize == 0 || - // y_in_tiles * kTileSize + kTileSize <= height) { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // partial_result = Reducer(partial_result, input[y][x]); + // partial_results[kTileWidth] = init_values; + // tile_in_y_bounds = height % kTileHeight == 0 || + // y_in_tiles * kTileHeight + kTileHeight <= height; + // tile_in_x_bounds = width % kTileWidth == 0 || + // x_in_tiles * kTileWidth + kTileWidth <= width; + // // The implementation handles y and x bound checks separately. + // if (tile_in_y_bounds && tile_in_x_bounds) { + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (x_offset : range(kTileWidth)) { + // x = x_in_tiles * kTileWidth + x_offset; + // partial_result = Reducer(partial_result[x_offset], input[y][x]); + // } // } // } else { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // if (y < height) { - // partial_result = Reducer(partial_result, input[y][x]); + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (y_offset : range(kTileHeight)) { + // x = x_in_tiles * kTileWidth + x_offset; + // if (y < height && x < width) { + // partial_result = Reducer(partial_result, input[y][x]); + // } // } // } // } - // AtomicReducer(&output[x], partial_result); + // for (x_offset : range(kTileWidth)) { + // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); + // } // } - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector<llvm::Value*> partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ + llvm::Twine(i * kTileWidth + x_offset)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } } // Emit an inner for-loop that partially reduces the elements in the given // tile. llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x = tile_index[1]; + llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x = ir_builder_.CreateZExtOrTrunc(x, index_ty); + y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + auto emit_tile_element_loop = [=](bool tile_in_y_bounds, + bool tile_in_x_bounds) -> Status { std::unique_ptr<llvm_ir::ForLoop> tile_element_loop = llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_const(0), - index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileHeight), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)), + &b_); + llvm::Value* y = b_.CreateNSWAdd( + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // y-in-bounds check before reading from the input. - if (!tile_in_bounds) { + // Unless we know that y is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, index_typed_const(height)), - "y_in_bounds", &ir_builder_); + b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + // Unless we know that x is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_x_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). 
This conversion is composed of a transposition from @@ -1406,67 +1135,95 @@ Status IrEmitterUnnested::EmitColumnReduction( const Shape input_matrix_shape = ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), {height, width}); - const llvm_ir::IrArray::Index input_matrix_index( - {y, x}, input_matrix_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, + &b_); + const IrArray::Index input_index = input_matrix_index .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); + {partial_reduction_result_addresses[i * kTileWidth + x_offset], + input_address}, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens)); } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); } + return Status::OK(); }; - // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's - // immediately beyond the tile. - llvm::Value* y_end = ir_builder_.CreateNSWAdd( - index_typed_const(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize))); - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, index_typed_const(height)), - ir_builder_.getInt1(height % kTileSize == 0)); - // The tile is entirely in bound if "height" is a multiple of kTileSize or + // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location + // that's immediately beyond the tile. + llvm::Value* y_end = b_.CreateNSWAdd( + index_typed_constant(kTileHeight), + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location + // that's immediately beyond the tile. + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileWidth), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* tile_in_y_bounds = + b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); + llvm::Value* tile_in_x_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); + // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit atomic - // operations to accumulate the partial reduction result to the output - // element. 
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::LlvmIfData if_tile_in_y_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); + // The tile is in x bounds if "width" is a multiple of kTileWidth or + // x_end <= width. + llvm_ir::LlvmIfData if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/false)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); + if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/false)); + + // After the nested if-then-else statement on tile_in_y_bounds and + // tile_in_x_bounds, emit atomic operations to accumulate the partial + // reduction result to the output element. + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + } } return Status::OK(); }; @@ -1478,7 +1235,7 @@ Status IrEmitterUnnested::EmitColumnReduction( static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1628,28 +1385,25 @@ Status IrEmitterUnnested::EmitRowReduction( {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = GetIndexTypeForKernel( - reduce, - launch_dimensions.block_count() * 
launch_dimensions.threads_per_block(), - &ir_builder_); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); std::vector<llvm::Value*> partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1658,25 +1412,25 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id"); + b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id"); + b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. 
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, ir_builder_.CreateNSWMul( - index_typed_const(kWarpSize), - ir_builder_.CreateNSWAdd( - index_typed_const(x_tile_size - 1), - ir_builder_.CreateNSWMul( - warp_id, index_typed_const(x_tile_size))))); + llvm::Value* last_x = b_.CreateNSWAdd( + lane_id, + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + index_typed_constant(x_tile_size - 1), + b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( - &ir_builder_, + &b_, /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, /*prevent_vectorization=*/false); @@ -1685,22 +1439,22 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = ir_builder_.CreateNSWAdd( + llvm::Value* z = b_.CreateNSWAdd( z_indvar, - ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile)); + b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", - /*start=*/index_typed_const(0), - /*end=*/index_typed_const(x_tile_loop_bound), + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(x_tile_loop_bound), /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = ir_builder_.CreateNSWAdd( + llvm::Value* x = b_.CreateNSWAdd( lane_id, - ir_builder_.CreateNSWMul( - index_typed_const(kWarpSize), - ir_builder_.CreateNSWAdd( - x_indvar, ir_builder_.CreateNSWMul( + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + x_indvar, b_.CreateNSWMul( warp_id, llvm::ConstantInt::get( index_ty, x_tile_size))))); @@ -1709,17 +1463,16 @@ Status IrEmitterUnnested::EmitRowReduction( if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, index_typed_const(width)), - "x_in_bounds", &ir_builder_); - // Points ir_builder_ to the then-block. + b_.CreateICmpULT(x, index_typed_constant(width)), + "x_in_bounds", &b_); + // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); + &b_); } // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = - ir_builder_.CreateAlloca(element_ir_type); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. 
We need to convert that to an index @@ -1737,21 +1490,20 @@ Status IrEmitterUnnested::EmitRowReduction( const Shape input_3d_tensor_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &b_); + const IrArray::Index input_index = input_3d_tensor_index .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, - &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose( normalized_input_shape, input_shape, - transpose_dimension_mapping, &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1765,14 +1517,14 @@ Status IrEmitterUnnested::EmitRowReduction( }; return ksl.For("z_tile", - /*start=*/index_typed_const(0), - /*end=*/index_typed_const(z_tile_size), + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(z_tile_size), /*step=*/1, emit_z_tile_element_loop); }; - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, index_typed_const(width))); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + b_.CreateICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1795,23 +1547,25 @@ Status IrEmitterUnnested::EmitRowReduction( // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? 
b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1826,20 +1580,18 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - y, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); + IrArray::Index(y, + ShapeUtil::GetSubshape( + output->shape(), reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); // We don't need to emit atomic operations if there is only one tile of // results. 'depth' is the z dimension, 'width' is the x dimension. 
if (z_tile_size >= depth && x_tile_size >= width) { @@ -1863,7 +1615,7 @@ Status IrEmitterUnnested::EmitRowReduction( static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1982,31 +1734,33 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { BuildInitializerThunk(reduce)); std::vector<std::unique_ptr<Thunk>> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(reduce)); + thunks.push_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique<SequentialThunk>(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + reduce, input->shape(), {[&](const IrArray::Index& index) { + return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); }}, - {[&](const llvm_ir::IrArray::Index& index) { + {[&](const IrArray::Index& index) { return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }}, dimensions_to_reduce, {reducer}, {{}}, {}); } - thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); + thunk_sequence_->emplace_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/true)); return IrEmitter::HandleReduce(reduce); } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { - return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( - tuple_element); + return ir_emitter_context_->buffer_assignment() + .GetUniqueTopLevelSlice(tuple_element) + .ok(); }); // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as @@ -2027,7 +1781,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); + thunk_sequence_->emplace_back( + BuildKernelThunk(tuple, /*implements_whole_instruction=*/true)); return IrEmitter::HandleTuple(tuple); } @@ -2052,7 +1807,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( BuildInitializerThunk(select_and_scatter)); std::vector<std::unique_ptr<Thunk>> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunks.push_back(BuildKernelThunk(select_and_scatter, + /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter)); @@ -2065,8 +1821,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( LaunchDimensions launch_dimensions = CalculateLaunchDimensions( source->shape(), ir_emitter_context_->device_description()); llvm::Type* index_type = GetIndexTypeForKernel( - select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + select_and_scatter, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; @@ -2089,114 +1845,106 @@ Status 
IrEmitterUnnested::HandleSelectAndScatter( // selected_index = I // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& source_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status { // Allocate space to keep the currently selected value, its index, and a // boolean flag if the value is initialized. The initialized_flag is set // false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_emitter_context_->llvm_module()), - "selected_value_address", &ir_builder_); + "selected_value_address", &b_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - index_type, index_typed_const(rank), "selected_index_address", - &ir_builder_); + index_type, index_typed_constant(rank), "selected_index_address", + &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); - ir_builder_.CreateStore(ir_builder_.getInt1(false), - initialized_flag_address); + b_.getInt1Ty(), "initialized_flag_address", &b_); + b_.CreateStore(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_, index_type); + llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, + index_type); std::vector<int64> window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); } - const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( + const IrArray::Index window_index = window_loops.AddLoopsForShape( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + &b_); // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. 
- llvm_ir::IrArray::Index operand_index(index_type, source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + IrArray::Index operand_index(index_type, source_index.size()); + llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], index_typed_const(window.dimensions(i).stride())); - operand_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); - llvm::Value* index_condition = ir_builder_.CreateICmpULT( + llvm::Value* strided_index = b_.CreateNSWMul( + source_index[i], index_typed_constant(window.dimensions(i).stride())); + operand_index[i] = b_.CreateNSWSub( + b_.CreateNSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = b_.CreateICmpULT( operand_index[i], - index_typed_const(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); + index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); // Only need to do something if the operand index is within the bounds. // First check if the initialized_flag is set. llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - ir_builder_.CreateLoad(initialized_flag_address), "initialized", - &ir_builder_); + b_.CreateLoad(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { + llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_); + const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - ir_builder_.CreateInBoundsGEP(selected_index_address, - {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + b_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); + IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); - ir_builder_.CreateStore(operand_data, selected_value_address); + operand_array.EmitReadArrayElement(operand_index, &b_); + b_.CreateStore(operand_data, selected_value_address); save_operand_index(operand_index); - ir_builder_.CreateStore(ir_builder_.getInt1(true), - initialized_flag_address); + b_.CreateStore(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently // visiting operand. 
- llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); + operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(PRED, ir_emitter_context_->llvm_module()), - "select_return_buffer", &ir_builder_); + "select_return_buffer", &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer); + llvm::Value* result = b_.CreateLoad(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = ir_builder_.CreateICmpNE( + llvm::Value* cond = b_.CreateICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); - ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), - selected_value_address); + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); + llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); + b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -2204,20 +1952,19 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // location is computed by calling the `scatter` function with the source // value and the current output value. 
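Note: the select logic above, together with the scatter step in the next hunk, corresponds to the loop nest below. This is a sequential 1-D host sketch with hypothetical names, not the emitted kernel; the real emitter pre-fills the output via a separate initializer thunk and performs the final update with an atomic nested-computation call, since several source elements can scatter into the same output slot.

    // Sketch only: 1-D select-and-scatter. select(kept, candidate) returning
    // false means the candidate wins, mirroring the "if-select-lhs" branch
    // above; scatter combines the source value into the chosen output slot.
    // `output` must already hold the init value.
    #include <cstdint>
    #include <functional>
    #include <vector>

    void SelectAndScatter1D(const std::vector<float>& operand,
                            const std::vector<float>& source,
                            std::vector<float>& output,
                            int64_t window_size, int64_t stride, int64_t pad_low,
                            const std::function<bool(float, float)>& select,
                            const std::function<float(float, float)>& scatter) {
      for (int64_t s = 0; s < static_cast<int64_t>(source.size()); ++s) {
        bool initialized = false;
        int64_t selected_index = 0;
        float selected_value = 0;
        for (int64_t w = 0; w < window_size; ++w) {
          const int64_t i = s * stride + w - pad_low;  // may fall outside operand
          if (i < 0 || i >= static_cast<int64_t>(operand.size())) continue;
          if (!initialized || !select(selected_value, operand[i])) {
            selected_value = operand[i];
            selected_index = i;
            initialized = true;
          }
        }
        if (initialized) {
          output[selected_index] = scatter(output[selected_index], source[s]);
        }
      }
    }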
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - llvm_ir::IrArray::Index selected_index(operand_index.GetType()); + &b_); + IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - selected_index.push_back( - ir_builder_.CreateLoad(selected_index_address_slot)); + llvm::Value* selected_index_address_slot = + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) - .EmitArrayElementAddress(source_index, &ir_builder_); + .EmitArrayElementAddress(source_index, &b_); llvm::Value* output_value_address = GetIrArray(*select_and_scatter, *select_and_scatter) - .EmitArrayElementAddress(selected_index, &ir_builder_); + .EmitArrayElementAddress(selected_index, &b_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, source_value_address); @@ -2232,7 +1979,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(select_and_scatter), index_type); } @@ -2260,15 +2007,92 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { } Status IrEmitterUnnested::HandleRng(HloInstruction* random) { - thunk_sequence_->push_back(BuildKernelThunk(random)); + thunk_sequence_->push_back( + BuildKernelThunk(random, /*implements_whole_instruction=*/true)); return IrEmitter::HandleRng(random); } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { - thunk_sequence_->push_back(BuildKernelThunk(select)); + thunk_sequence_->push_back( + BuildKernelThunk(select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + std::vector<std::unique_ptr<Thunk>> thunks; + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + if (values != nullptr) { + // TODO(b/26783907): Also sort the values by their corresponding key. + return Unimplemented("Key/Value Sort is not implemented on GPU"); + } + + // First copy the operand to the output, so that we can sort in-place. + // TODO(b/26783907): Share buffer of output and operand when it is possible. 
+ if (sort->operand(0)->IsConstant()) { + thunks.push_back(MakeUnique<HostToDeviceCopyThunk>( + /*source_address=*/sort->operand(0)->literal().untyped_data(), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } else { + thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + /*source_address=*/GetAllocationSlice(*sort->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*sort), + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort)); + } + + int64 dimension_to_sort = sort->dimensions(0); + int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort); + int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); + auto index_type = b_.getInt64Ty(); + + // Naive C++ code for the outer loops: + // + // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound); + // ++stage) { + // int64 first_xor_mask = (1LL << (stage + 1)) - 1; + // SortInPlace(first_xor_mask); + // for (int64 mask = stage - 1; mask >= 0; --mask) { + // int64 later_xor_mask = 1LL << mask; + // SortInPlace(later_xor_mask); + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + for (int64 stage = 0; stage < num_stages; ++stage) { + for (int64 mask = stage; mask >= 0; --mask) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + sort->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + + llvm::Value* xor_mask; + if (mask == stage) { + xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + } else { + xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + } + + TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( + dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask, + &b_, &launch_dimensions)); + } + } + + thunk_sequence_->emplace_back( + MakeUnique<SequentialThunk>(std::move(thunks), sort)); + return Status::OK(); +} + +Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { + thunk_sequence_->push_back( + BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); + return IrEmitter::HandleTupleSelect(tuple_select); +} + Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. @@ -2304,12 +2128,12 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), - /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. 
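Note: the new HandleSort above lowers to a bitonic sorting network; the naive C++ loops in its comment pick an xor mask per step, and each EmitSortInPlace call compare-exchanges every element with its mask partner. A host-side sketch of one such step (hypothetical helper, not the emitted IR; f32 keys only, as the hunk itself rejects key/value sort):

    // Sketch: one SortInPlace(xor_mask) step of the bitonic network. Running
    // it over the stage/mask schedule from the HandleSort comment sorts
    // `keys` ascending. Out-of-range partners are skipped, so the dimension
    // bound need not be a power of two.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    void SortInPlaceStep(std::vector<float>& keys, int64_t xor_mask) {
      const int64_t bound = static_cast<int64_t>(keys.size());
      for (int64_t i = 0; i < bound; ++i) {
        const int64_t partner = i ^ xor_mask;
        if (partner > i && partner < bound && keys[i] > keys[partner]) {
          std::swap(keys[i], keys[partner]);  // smaller key ends up at i
        }
      }
    }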
thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers, - GetAllocationSlice(*crs), crs)); + GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( MakeUnique<SequentialThunk>(std::move(thunks), crs)); return Status::OK(); @@ -2324,6 +2148,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } +Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { + thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed)); + return Status::OK(); +} + // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -2443,7 +2272,8 @@ GetHloBufferSlices(const HloInstruction* hlo, } std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, int unroll_factor) { + const HloInstruction* inst, bool implements_whole_instruction, + int unroll_factor) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2508,18 +2338,16 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = - ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {ir_builder_.getInt64(slice.offset())}); + llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. llvm::Type* int8_double_pointer = - llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); - loc = ir_builder_.CreateLoad( - ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + loc = b_.CreateBitCast(loc, int8_double_pointer); + loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2531,11 +2359,12 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); } else { bindings_.SetTempBufferBase( - llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); + llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()), - inst, unroll_factor); + implements_whole_instruction ? 
inst : nullptr, + unroll_factor); } std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2569,7 +2398,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( ShapeTree<BufferAllocation::Slice> slices(inst->shape()); slices.ForEachMutableElement( - [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { *slice = ir_emitter_context_->buffer_assignment() .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); @@ -2577,6 +2406,23 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( return MakeUnique<InfeedThunk>(slices, inst); } +std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); + + ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + auto status_or_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice( + inst->operand(0), index); + if (status_or_slice.ok()) { + *slice = status_or_slice.ConsumeValueOrDie(); + } + }); + return MakeUnique<OutfeedThunk>(std::move(slices), inst); +} + namespace { double GetScalarConstantAsDouble(const Literal& literal) { switch (literal.shape().element_type()) { @@ -2692,6 +2538,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( init_value = hlo->operand(init_value->parameter_number()); } + // Initializer thunks don't implement a whole instruction, and we want to + // profile the whole instruction instead of the individual thunks it consists + // of. Therefore we pass nullptr as the HloInstruction* to the thunks we + // generate below. + // // In the common case, the initializer is a constant. In this case, emit a // device-memset call if we can. Currently StreamExecutor only supports // zeroing and 32-bit memsets. @@ -2705,7 +2556,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( ArraySlice<uint8> literal_bytes( reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)}; + return { + MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2723,7 +2575,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {MakeUnique<Memset32BitValueThunk>( - pattern32, GetAllocationSlice(*hlo, index), hlo)}; + pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2734,12 +2586,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); return {MakeUnique<Memset32BitValueThunk>( - word, GetAllocationSlice(*hlo, index), hlo)}; + word, GetAllocationSlice(*hlo, index), nullptr)}; } } // Otherwise fall back to our slow initializer code. 
- std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo); + std::unique_ptr<KernelThunk> kernel_thunk = + BuildKernelThunk(hlo, /*implements_whole_instruction=*/false); LaunchDimensions launch_dimensions = CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), ir_emitter_context_->device_description()); @@ -2751,12 +2604,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value))); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const llvm_ir::IrArray::Index& index) { + [=](const IrArray::Index& index) { return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, - &ir_builder_) + GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally @@ -2940,41 +2792,546 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), - launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo), - GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), - &ir_builder_)); + launch_dimensions, &b_, unroll_factor) + .EmitLoop( + IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multiple outputs fusion, we need to emit each operand and the root. - std::vector<llvm_ir::IrArray> output_arrays; + // For multioutput fusion, we need to emit each operand and the root. + std::vector<IrArray> output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, - &ir_builder_, unroll_factor) + &b_, unroll_factor) .EmitLoop(IrName(&hlo), GetIndexTypeForKernel( - &hlo, launch_dimensions.launch_bound(), &ir_builder_))); + &hlo, launch_dimensions.launch_bound(), &b_))); std::vector<llvm::Value*> tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, - module_); + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); + CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind()); return EmitTargetElementLoopInThunk(hlo, element_generator, static_cast<KernelThunk*>(LastThunk())); } +int IrEmitterUnnested::ConstructIrArrayForOutputs( + const HloInstruction& hlo, std::vector<IrArray>* output_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays->push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays->push_back(GetIrArray(hlo, hlo)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo, 
std::vector<IrArray>* param_arrays) { + int64 num_params = hlo.operands().size(); + param_arrays->reserve(num_params); + for (const HloInstruction* param : hlo.operands()) { + param_arrays->push_back(GetIrArray(*param, hlo)); + } + return num_params; +} + +int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + const HloInstruction& hlo, const std::vector<IrArray>& output_arrays, + tensorflow::gtl::ArraySlice<int64> reduced_output_dims, + std::vector<Shape>* output_reduced_shapes, + std::vector<IrArray>* output_in_reduced_shape_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_in_reduced_shape_arrays->reserve(num_outputs); + output_reduced_shapes->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), + reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); + } + } else { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + hlo.shape().element_type(), reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( + const HloInstruction& hlo, const std::vector<IrArray>& param_arrays, + const std::vector<llvm::Value*>& param_buffers, + tensorflow::gtl::ArraySlice<int64> reduced_output_dims, + std::vector<Shape>* param_reduced_shapes, + std::vector<IrArray>* param_in_reduced_shape_arrays) { + int64 num_params = hlo.operands().size(); + param_in_reduced_shape_arrays->reserve(num_params); + param_reduced_shapes->reserve(num_params); + for (int64 id = 0; id < num_params; ++id) { + if (param_buffers[id] == nullptr) { + param_reduced_shapes->push_back(Shape()); + param_in_reduced_shape_arrays->push_back(IrArray()); + continue; + } + const HloInstruction* param = hlo.operand(id); + param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + param->shape().element_type(), + Permute({0, 2, 1}, reduced_output_dims))); + param_in_reduced_shape_arrays->push_back( + param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_)); + } + return num_params; +} + +namespace { + +// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the +// thread lives within a square tile of size tile_size (so thread blocks are of +// size tile_size * tile_size). +std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile( + llvm::IRBuilder<>* builder, llvm::Value* tile_size, + int64 threads_per_tile) { + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, threads_per_tile, + llvm::cast<llvm::Instruction>(thread_id)); + thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), + /*isSigned=*/true, "thread.id.x"); + auto x = builder->CreateURem(thread_id, tile_size); + auto y = builder->CreateUDiv(thread_id, tile_size); + return std::make_tuple(y, x); +} + +// Reads block_idx.x, casts it to type index_ty, and adds the assumption that +// it's in the range [0, num_blocks]. 
+llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, + int64 num_blocks) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, num_blocks, + llvm::cast<llvm::Instruction>(block_id)); + return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, + "block.id.x"); +} + +// Emits code to process up to (tile_size/num_rows) elements in a tile, given +// `emit_elem_function` is the function to emit code to process one element, `y` +// and `x` are the coordinates for the first element to process, and `index` is +// the index for the origin of the tile. Emits bounds check to ensure that each +// processed element is within the boundary defined by `tile_width` and +// `tile_height`. +void EmitTiledElementalCodeWithBoundsCheck( + int64 tile_size, int64 num_rows, const IrArray::Index& index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function<void(const IrArray::Index&, llvm::Value*)>& + emit_elem_function) { + llvm::Type* index_ty = tile_width->getType(); + // Emits a constant value with index type. + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = builder->CreateAdd(index[dim], addend); + return index; + }; + + auto emit_full_tile = [&] { + for (int64 i = 0; i < tile_size; i += num_rows) { + auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); + auto y_loc = builder->CreateAdd(index_typed_constant(i), y); + emit_elem_function(source_idx, y_loc); + } + }; + + auto emit_last_row = [&] { + ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { + // tile_height_upper_bound = + // ceil(tile_height / num_rows) * num_rows + auto tile_height_upper_bound = builder->CreateMul( + builder->CreateUDiv( + builder->CreateAdd(tile_height, + index_typed_constant(num_rows - 1)), + index_typed_constant(num_rows)), + index_typed_constant(num_rows)); + ksl->ForReturnVoid( + loop_name, /*start=*/index_typed_constant(0), + /*end=*/tile_height_upper_bound, + /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { + auto y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), + y_loc); + }); + }); + }); + }; + ksl->IfReturnVoid( + "full_tile", + builder->CreateAnd( + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), + emit_full_tile, emit_last_row); +} +} // namespace + +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// which have a shape that is a 0-2-1 transpose of the output tensors. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape of +// three components 0-1-2 in the order major to minor. The x- and y- dimensions +// of the tensors are tiled in square tiles of edge length `kTileSize`. 
Each +// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each +// thread copies kTileSize/kNumRows elements from the input to a shared memory +// tile, then the otherwise "regular hlo kernel" reads from the shared memory +// instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient +// to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims, + tensorflow::gtl::ArraySlice<int64> tiled_param_ids) { + // Parameters for the tiling algorithm. + constexpr int64 kTileSize = 32; + constexpr int64 kNumRows = 4; + constexpr int64 kThreadsPerTile = kTileSize * kNumRows; + + // Construct IrArrays for the inputs and outputs. + std::vector<IrArray> output_arrays; + int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); + std::vector<IrArray> param_arrays; + int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + + // Allocate shared memory buffers to store the tiled inputs. + std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr); + for (int64 id : tiled_param_ids) { + const HloInstruction* param = hlo->operand(id); + // Add 1 to the minor dimension to reduce shared memory bank conflicts. + llvm::Type* tile_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + param->shape().element_type(), module_), + kTileSize + 1), + kTileSize); + const int kNVPTXSharedMemoryAddrSpace = 3; + auto* tile_base_ptr = new llvm::GlobalVariable( + *b_.GetInsertBlock()->getParent()->getParent(), tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), + llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); + param_shmem_buffers[id] = tile_base_ptr; + VLOG(3) << "Added shmem buffer for parameter " << id << ": " + << llvm_ir::DumpToString(*tile_base_ptr); + } + + // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result + // for the purpose of tiling. Calculate the logical output dimensions in the + // tile from the reduced output dimensions. + std::vector<int64> output_dims_in_tiles = std::vector<int64>( + reduced_output_dims.begin(), reduced_output_dims.end()); + CHECK_EQ(output_dims_in_tiles.size(), 3); + for (int i = 1; i < 3; ++i) { + output_dims_in_tiles[i] = + CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize); + } + const int64 num_tiles = + c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>()); + LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); + + llvm::Type* index_ty = + GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Cast each output IrArray to its corresponding reduced shape and keep the + // reduced shape live during IR emission. 
+ std::vector<IrArray> output_in_reduced_shape_arrays; + std::vector<Shape> output_reduced_shapes; + CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, + &output_in_reduced_shape_arrays), + num_outputs); + + // For each tiled parameter, cast its input IrArray to the corresponding + // reduced shape and keep the reduced shape live during IR emission. + std::vector<IrArray> param_in_reduced_shape_arrays; + std::vector<Shape> param_reduced_shapes; + CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( + *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), + num_params); + + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* x; + llvm::Value* y; + std::tie(y, x) = CalculateYXCoordinateWithinTile( + &b_, index_typed_constant(kTileSize), kThreadsPerTile); + + // Calculate the index for the current output tile from block_id. + const IrArray::Index output_tile_index( + GetBlockIdx(&b_, index_ty, num_tiles), + ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, + output_dims_in_tiles), + &b_); + + // Output tile origin is the index for the first element of the current output + // tile. + const IrArray::Index output_tile_origin = [&] { + IrArray::Index index = output_tile_index; + for (int i = 1; i < 3; ++i) { + index[i] = + b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); + } + return index; + }(); + + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + // Calculate the current output tile bounds in each of the logical dimensions. + std::vector<llvm::Value*> output_tile_bounds(3); + for (int i = 1; i < 3; ++i) { + // Only last row or column may not have full size. + output_tile_bounds[i] = b_.CreateSelect( + b_.CreateICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); + } + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. + auto emit_tiled_elemental_code_with_bounds_check = + [&](const IrArray::Index& index, const string& loop_name, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function<void(const IrArray::Index&, llvm::Value*)>& + emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck( + kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, + tile_height, emit_elem_function); + }; + + // Adds `addend` to the given `dim` of `index`. 
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = b_.CreateAdd(index[dim], addend); + return index; + }; + const IrArray::Index input_index = + offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + b_.CreateStore( + input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + } + }); + + // Wait for all threads to reach this point, lest we copy a value from tile to + // output before the other thread copies it from input to tile. + // This is `__syncthreads` in CUDA. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + + const IrArray::Index output_index = + offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Write to output[index] by emitting code like normal, except that values for + // the tiled parameters are read from the shmem buffers. + if (hlo->opcode() == HloOpcode::kCopy) { + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + // TODO(jlebar): Add AA metadata to this load. + llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( + b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, load_from_shmem_buffer, &b_); + }); + } else { + CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(param_arrays, &elem_emitter); + tiled_param_info.set_y(y_loc); + fused_emitter.SetTiledParameterInfo(&tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( + index, output_reduced_shapes[0], output_arrays[0].GetShape(), + &b_); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = + output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + CHECK(output_value->getType()->isStructTy()); + CHECK_EQ(output_value->getType()->getStructNumElements(), + output_in_reduced_shape_arrays.size()); + for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { + output_in_reduced_shape_arrays[i].EmitWriteArrayElement( + index, b_.CreateExtractValue(output_value, i), &b_); + } + } else { + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, output_value, &b_); + } + }); + } + + // For multioutput fusion, emit a tuple with all the individual outputs. 
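Note: the 0-2-1 tiling emitted above stages each 32x32 tile of a transposed parameter through shared memory (padded to 32x33 to avoid bank conflicts) using a 32x4 thread block, then reads it back transposed after a barrier. A plain CUDA analogue of that pattern for a 2-D f32 array, launched with a grid of ceil(width/32) x ceil(height/32) blocks of 32x4 threads (illustrative only; the emitter produces NVPTX-targeted LLVM IR, not CUDA C++):

    // Sketch: shared-memory tiled transpose equivalent to the emitted scheme.
    constexpr int kTileSize = 32;  // same constants as EmitHlo021Tile
    constexpr int kNumRows = 4;

    __global__ void TransposeTiled(const float* in, float* out,
                                   int height, int width) {
      // +1 on the minor dimension avoids shared-memory bank conflicts.
      __shared__ float tile[kTileSize][kTileSize + 1];
      int x = blockIdx.x * kTileSize + threadIdx.x;
      int y = blockIdx.y * kTileSize + threadIdx.y;
      for (int i = 0; i < kTileSize; i += kNumRows) {
        if (x < width && y + i < height) {
          tile[threadIdx.y + i][threadIdx.x] = in[(y + i) * width + x];
        }
      }
      __syncthreads();  // the nvvm_barrier0 between the input and output loops
      x = blockIdx.y * kTileSize + threadIdx.x;
      y = blockIdx.x * kTileSize + threadIdx.y;
      for (int i = 0; i < kTileSize; i += kNumRows) {
        if (x < height && y + i < width) {
          out[(y + i) * height + x] = tile[threadIdx.x][threadIdx.y + i];
        }
      }
    }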
+ if (hlo->IsMultiOutputFusion()) { + std::vector<llvm::Value*> tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, + module_); + } + + return launch_dimensions; +} + +bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); + CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); + CHECK(opcode != HloOpcode::kFusion || + hlo->fusion_kind() == HloInstruction::FusionKind::kLoop) + << "Only loop fusions are supported."; + + const Shape& output_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + + // If the output_shape is reduced to 021 shape, find all the parameters of the + // hlo that are in the corresponding 012 shape. + std::vector<int64> params_012; + optional<std::vector<int64>> reduced_dims_021; + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + HloInstruction* operand = hlo->mutable_operand(operand_idx); + auto find_transpose_result = + llvm_ir::FindTranspose021(operand->shape(), output_shape); + if (!find_transpose_result.has_value()) { + continue; + } + const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result; + if (!reduced_dims_021.has_value()) { + reduced_dims_021 = curr_reduced_dims_021; + } + if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + // There is more than one possible transpose. Instead of picking one + // transpose, we simply give up here. + return false; + } + params_012.push_back(operand_idx); + } + + if (!reduced_dims_021.has_value()) { + return false; + } + + if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled || + (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) { + return false; + } + + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the + // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb + // shared memory per SM. (This is increased to 96kb in Volta, but we don't + // use this, in part because it eats into our L1 cache space.) + // + // For correctness we need to ensure that we don't make more than 48kb worth + // of shmem tiles per block. And for performance, we'd probably like to use + // significantly less, so that we can fit more than one block at a time on a + // gpu core. + // + // We say without benchmarks that we want at least 3 threads/block, + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose + // which params get the shmem transpose treatment arbitrarily; it's not clear + // if there's a Right Choice. + // + // This is only sound if tiled transposes are the only place where we use + // shared memory in fusions. If in the future other fusile ops use shared + // memory, we'll have to adjust this heuristic. + constexpr int kMinBlocksPerCore = 3; + constexpr int64 kShmemPerCore = 48 * 1024; + int64 shmem_used = 0; + for (int64 i = 0; i < params_012.size(); ++i) { + const HloInstruction* operand = hlo->operand(params_012[i]); + shmem_used += + 32 * 33 * + ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); + + if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { + // Erase this element and everything after it from params_012. 
+ params_012.resize(i); + break; + } + } + + VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); + thunk_sequence_->emplace_back( + BuildKernelThunk(hlo, /*implements_whole_instruction=*/true)); + const LaunchDimensions launch_dimensions = + EmitHlo021Tile(hlo, *reduced_dims_021, params_012); + UpdateLaunchDimensions(launch_dimensions, LastThunk(), + ir_emitter_context_->llvm_module()); + + return true; +} + } // namespace gpu } // namespace xla
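Note: as a quick check of the shared-memory budget in CheckAndEmitHloWithTile021 above, each tiled parameter uses a 32 x 33 element buffer, so with 4-byte elements three tiles per block stay inside the 48KB-per-core, 3-blocks-per-core budget while a fourth would not. The arithmetic, as a sketch (kTileBytesF32 is a name introduced here; the other constants come from the hunk):

    #include <cstdint>

    constexpr int kMinBlocksPerCore = 3;            // from the hunk above
    constexpr int64_t kShmemPerCore = 48 * 1024;    // bytes of shmem per SM
    constexpr int64_t kTileBytesF32 = 32 * 33 * 4;  // 4224 bytes per f32 tile

    static_assert(kMinBlocksPerCore * 3 * kTileBytesF32 <= kShmemPerCore,
                  "three f32 shmem tiles per block fit in the budget");
    static_assert(kMinBlocksPerCore * 4 * kTileBytesF32 > kShmemPerCore,
                  "a fourth f32 tile would exceed it");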