| field | value | date |
|---|---|---|
| author | Bixia Zheng <bixia@google.com> | 2018-07-04 15:55:59 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-04 15:58:20 -0700 |
| commit | eb8a110e52a8e8ed7b9db48e01115d3110e918c2 (patch) | |
| tree | ff984481a05efb9efd0c78c0c7065b71d268c8e6 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
| parent | 548601aee34f3eb0e6857b5ce6df25db1b50a4b4 (diff) | |
[XLA:GPU] Enhance the tiled 0-2-1 transpose algorithm to handle fusion.
Add class TiledParameterInfo to provide FusedIrEmitter with the information it
needs to read the contents of a tiled parameter from the shared-memory tile
buffer instead of from the original input memory.
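For reference, a sketch of how the emitter wires this together, mirroring the new code in EmitHlo021Tile in the diff below (the variables `param_buffers`, `param_arrays`, `elem_emitter`, `y`, `x`, and `y_loc` are defined there):

```cpp
// `param_buffers` holds one shared-memory tile pointer per operand
// (nullptr for operands that are not tiled); (y, x) is the thread's
// coordinate within the tile.
llvm_ir::TiledParameterInfo tiled_param_info(param_buffers, y, x);
FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
tiled_param_info.set_y(y_loc);  // y coordinate of the element being emitted
fused_emitter.SetTiledParameterInfo(&tiled_param_info);
// From here on, reads of tiled parameters are served from the tile buffers
// rather than from the original input memory.
TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
```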
Reimplement the tiled 0-2-1 transpose algorithm, which previously applied only
to copy instructions, in a more general way so that it can handle both fusion
instructions and copy instructions.
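For intuition, here is a minimal standalone CUDA sketch of the underlying tiled-transpose pattern (illustrative only: the actual change emits the equivalent LLVM IR, and this kernel handles a single 2-D slice of the 0-2-1 transpose rather than the batched case):

```cpp
constexpr int kTileSize = 32;  // matches the tile_size chosen in this change
constexpr int kNumRows = 4;    // matches the num_rows chosen in this change

// Transposes an n1 x n2 matrix into an n2 x n1 matrix. Launch with
// grid (ceil(n2/32), ceil(n1/32)) and block (32, 4).
__global__ void TransposeTiled(const float* in, float* out, int n1, int n2) {
  // The +1 on the minor dimension avoids shared memory bank conflicts
  // (see the sketch after the next paragraph).
  __shared__ float tile[kTileSize][kTileSize + 1];
  int x = blockIdx.x * kTileSize + threadIdx.x;
  int y = blockIdx.y * kTileSize + threadIdx.y;
  // Each thread copies kTileSize / kNumRows elements into the tile.
  for (int i = 0; i < kTileSize; i += kNumRows) {
    if (x < n2 && y + i < n1) {
      tile[threadIdx.y + i][threadIdx.x] = in[(y + i) * n2 + x];
    }
  }
  __syncthreads();
  // Write the tile back with the block coordinates swapped.
  x = blockIdx.y * kTileSize + threadIdx.x;
  y = blockIdx.x * kTileSize + threadIdx.y;
  for (int i = 0; i < kTileSize; i += kNumRows) {
    if (x < n1 && y + i < n2) {
      out[(y + i) * n1 + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}
```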
The original tiled 0-2-1 transpose implementation incorrectly used
(tile_size+1) rows for the tile buffer to reduce shared memory bank conflicts,
when it should have used (tile_size+1) columns instead. This is a performance
issue and is fixed in the new implementation.
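The distinction matters because shared memory banks are assigned along the minor dimension. A sketch of the effect, assuming 32 four-byte banks as on current NVIDIA GPUs:

```cpp
__global__ void BankConflictDemo() {
  // Element [y][x] of a row-major float array with row width W lives in bank
  // (y * W + x) % 32. Reading a tile column (fixed x, varying y) from a
  // 32-wide array strides the banks by 32, so all 32 accesses hit one bank:
  __shared__ float bad[32][32];       // column read: 32-way bank conflict
  // Widening each row to 33 makes the column stride 33, coprime with 32,
  // which spreads the accesses across all banks:
  __shared__ float good[32][32 + 1];  // column read: conflict-free
  // Padding the other dimension, float[32 + 1][32], merely adds an unused
  // row; the row width stays 32 and the conflict remains -- the bug fixed
  // by this change.
  (void)bad;
  (void)good;
}
```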
The original tiled 0-2-1 transpose implementation also did not generate LLVM
alias metadata for the loads and stores of the tensors, due to a bug where
IrArray::CastToShape failed to copy metadata to the new IrArray object. This is
also a performance issue and is fixed in this change.
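The fix amounts to propagating the existing alias metadata onto the new IrArray. A generic sketch of that kind of propagation using standard LLVM APIs (`CopyAliasMetadata` is a hypothetical helper for illustration, not the function in this change):

```cpp
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"

// Copies the alias-analysis metadata of one memory instruction to another,
// so loads/stores emitted through the new object keep their noalias info.
void CopyAliasMetadata(llvm::Instruction* from, llvm::Instruction* to) {
  to->setMetadata(llvm::LLVMContext::MD_alias_scope,
                  from->getMetadata(llvm::LLVMContext::MD_alias_scope));
  to->setMetadata(llvm::LLVMContext::MD_noalias,
                  from->getMetadata(llvm::LLVMContext::MD_noalias));
}
```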
Modify KernelSupportLibrary to support emitting an if statement with a given
branch name prefix.
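Usage, as it appears in the new emitter code in the diff below (the string argument names the basic blocks generated for the branch, which makes the emitted LLVM IR easier to read and debug):

```cpp
KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
// "x_in_tile" becomes the prefix of the branch's basic block names.
ksl.IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] {
  // ... emit the in-bounds element copy ...
});
```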
Add test cases for the new implementation.
PiperOrigin-RevId: 203310403
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 826 |
1 file changed, 504 insertions(+), 322 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 945ad371e8..186672047b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -697,6 +697,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     return Status::OK();
   }
 
+  if (CheckAndEmitHloWithTile021(fusion)) {
+    return Status::OK();
+  }
+
   CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
   int unroll_factor = ComputeMaxUnrollFactor(fusion);
 
@@ -705,302 +709,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   return IrEmitter::HandleFusion(fusion);
 }
 
-namespace {
-
-// Returns the indices of the first elements of all consecutive subarrays of the
-// given array. For example:
-// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
-  std::vector<size_t> is = {0};
-  for (size_t i = 1; i < xs.size(); ++i) {
-    if (1 != xs[i] - xs[i - 1]) {
-      is.push_back(i);
-    }
-  }
-  return is;
-}
-
-// Merges the sequences of dimensions of the given shape which start at the
-// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
-                      const Shape& shape) {
-  std::vector<int64> dimensions;
-  for (size_t i = 1; i <= segs.size(); ++i) {
-    dimensions.push_back(std::accumulate(
-        shape.dimensions().begin() + segs[i - 1],
-        shape.dimensions().begin() +
-            (segs.size() == i ? shape.dimensions().size() : segs[i]),
-        1, std::multiplies<int64>()));
-  }
-  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
-                                                  dimensions);
-}
-
-// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
-// if so, the normalized and rank-reduced shapes. The shapes must have the same
-// dimensions, so this considers layout only.
-//
-// This function recognizes higher-rank transposes which are elementwise
-// equivalent to a 0-2-1 transpose.
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
-  CHECK(ShapeUtil::Compatible(a, b));
-  std::vector<int64> perm(a.dimensions().size());
-  {
-    auto layout_a_orig = LayoutUtil::MinorToMajor(a);
-    std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
-    auto layout_b_orig = LayoutUtil::MinorToMajor(b);
-    std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
-    for (size_t i = 0; i < perm.size(); ++i) {
-      perm[i] = PositionInContainer(layout_b, layout_a[i]);
-    }
-  }
-  auto segs = ConsecutiveSegments(perm);
-  Shape norm_a =
-      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
-  Shape norm_b =
-      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
-  if (3 == segs.size() && 0 == perm[0]) {
-    Shape reduced_a = MergeDimensions(segs, norm_a);
-    Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
-        b.element_type(),
-        Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
-    return std::make_tuple(true, reduced_a, reduced_b);
-  }
-  return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
-}
-
-// Returns whether the given shapes are potentially of a 0-2-1 transpose.
-// As 0-2-1 is a self-inverse permutation, which shape is input or output is
-// arbitrary.
-bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
-  return 3 == b.dimensions().size() &&
-         ShapeUtil::Compatible(
-             ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
-             ShapeUtil::PermuteDimensions(
-                 {0, 2, 1},
-                 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
-                     b)));
-}
-
-// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
-// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
-// transposes one tile: each thread copies a row from the input to a shared
-// memory tile, then copies a column from the shared memory tile to the output.
-//
-// `tile_size` should usually be same as warp size.
-//
-// Returns (number of tiles = number of thread blocks needed).
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
-// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
-int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
-                            const int64 tile_size, const int64 num_rows,
-                            llvm::IRBuilder<>* builder) {
-  // Adds `addend` to the given `dim` of `index`.
-  auto offset_dim = [builder](llvm_ir::IrArray::Index index,
-                              llvm::Value* addend, int64 dim) {
-    index[dim] = builder->CreateAdd(index[dim], addend);
-    return index;
-  };
-
-  CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
-
-  Shape input_shape =
-      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
-          input.GetShape());
-  Shape output_shape =
-      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
-          output.GetShape());
-  input = input.CastToShape(input_shape, builder);
-  output = output.CastToShape(output_shape, builder);
-
-  llvm::Type* tile_type = llvm::ArrayType::get(
-      llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
-      // One extra here to avoid share memory bank conflict
-      tile_size + 1);
-  auto* tile = new llvm::GlobalVariable(
-      *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
-      /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
-      llvm::UndefValue::get(tile_type), "tile", nullptr,
-      llvm::GlobalValue::NotThreadLocal,
-      /*AddressSpace=*/3 /* GPU shared memory */);
-
-  // let x = threadIdx.x
-  llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
-      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
-  llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
-                            static_cast<llvm::Instruction*>(x));
-  x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
-                             "thread.id.x");
-
-  // computing logical thread ids
-  // logical_x = x % tile_size
-  auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
-
-  // logical_y = x / tile_size
-  auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
-
-  // `emit_cp` emits equivalent to following pseudocode:
-  // if (tile_size == tile_width && tile_size == tile_height) {
-  //   unroll for (i in range(0, tile_size, num_rows)) {
-  //     emit_cp_element(index + {0, i, 0}, y + logical_y);
-  //   }
-  // } else if (x < tile_width) {
-  //   tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
-  //   for (i in range(0, tile_height_upperbound, num_rows)) {
-  //     y_loc = i + logical_y;
-  //     if (y_loc < tile_height)
-  //      emit_cp_element(index + {0, i, 0}, y_loc);
-  //   }
-  // }
-  //
-  // We use this to emit both the copy from input to tile and the copy from tile
-  // to output.
-  //
-  // `index` is the origin of the row or column in the input or output array.
-  //
-  // `emit_cp_element(index, y)` emits code to copy a single element between the
-  // tile and the input or output array, where `y` is the `y`-position in the
-  // tile, whether which is row or column is a function of whether we're copying
-  // from input or to output, and `index` is the index into the input or output
-  // array.
-  auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
-                       logical_y](
-      std::function<void(const llvm_ir::IrArray::Index&,
-                         llvm::Value*)>
-          emit_cp_element,
-      llvm::Value* tile_width, llvm::Value* tile_height,
-      const llvm_ir::IrArray::Index& index,
-      const string& loop_name) {
-    llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
-        builder->CreateAnd(
-            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
-            builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
-        "not_last_row", builder);
-    builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
-    for (int64 i = 0; i < tile_size; i += num_rows) {
-      auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
-      auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
-      emit_cp_element(source_idx, y_loc);
-    }
-    builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
-    llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
-        builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
-    builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
-
-    // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
-    auto tile_height_upper_bound = builder->CreateMul(
-        builder->CreateUDiv(
-            builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
-            builder->getInt64(num_rows)),
-        builder->getInt64(num_rows));
-
-    auto loop = llvm_ir::ForLoop::EmitForLoop(
-        loop_name, builder->getInt64(0), tile_height_upper_bound,
-        builder->getInt64(num_rows), builder);
-    llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
-    builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
-
-    auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
-    auto if_y_in_tile = llvm_ir::EmitIfThenElse(
-        builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
-    builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
-
-    emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
-                    y_loc);
-    builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
-  };
-
-  auto input_dims_in_tiles = input_shape.dimensions();
-  // Unpermuted dimensions are untiled.
-  for (int i = 1; i < 3; ++i) {
-    input_dims_in_tiles[i] =
-        CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
-  }
-  int64 num_tiles =
-      std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
-                      std::multiplies<int64>());
-  const llvm_ir::IrArray::Index input_tile_index(
-      /*linear=*/builder->CreateIntCast(
-          llvm_ir::AddRangeMetadata(
-              0, num_tiles,
-              static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
-                  llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
-                  builder))),
-          builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
-      ShapeUtil::MakeShapeWithDescendingLayout(
-          PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
-      builder);
-  const llvm_ir::IrArray::Index input_tile_origin = ({
-    llvm_ir::IrArray::Index index = input_tile_index;
-    for (int i = 1; i < 3; ++i) {
-      index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
-                                    "tile_origin." + std::to_string(i));
-    }
-    index;
-  });
-  const llvm_ir::IrArray::Index input_index =
-      offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
-                 /*dim=*/1);
-  std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
-  // Only last row or column may not have full size.
-  for (int i = 1; i < 3; ++i) {
-    tile_dims[i] = builder->CreateSelect(
-        builder->CreateICmpEQ(input_tile_index[i],
-                              builder->getInt64(input_dims_in_tiles[i] - 1)),
-        builder->getInt64(input_shape.dimensions(i) -
-                          (input_dims_in_tiles[i] - 1) * tile_size),
-        builder->getInt64(tile_size), "tile_size");
-  }
-
-  // Load data from input memory to shared memory tile.
-  emit_cp_tile(
-      // tile[y, x] = input_array[index]
-      [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
-                                         llvm::Value* y) {
-        builder->CreateStore(
-            input.EmitReadArrayElement(index, builder, "input_element"),
-            builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
-      },
-      tile_dims[2], tile_dims[1], input_index, "input");
-
-  // Wait for all threads to reach this point, lest we copy a value from tile to
-  // output before the other thread copies it from input to tile.
-  // This is `__syncthreads` in CUDA.
-  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
-
-  const llvm_ir::IrArray::Index output_tile_index(
-      Permute({0, 2, 1}, input_tile_index.multidim()));
-  const llvm_ir::IrArray::Index output_tile_origin(
-      Permute({0, 2, 1}, input_tile_origin.multidim()));
-  const llvm_ir::IrArray::Index output_index =
-      offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
-                 logical_y, /*dim=*/1);
-
-  // Store data from shared memory tile to output memory.
-  emit_cp_tile(
-      // output_array[index] = tile[x, y]
-      [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
-                                          llvm::Value* y) {
-        output.EmitWriteArrayElement(
-            index,
-            builder->CreateLoad(
-                builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
-                "output_element"),
-            builder);
-      },
-      tile_dims[1], tile_dims[2], output_index, "output");
-
-  return num_tiles;
-}
-
-}  // namespace
-
 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
   if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
                                       *copy)) {
@@ -1012,26 +720,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
     thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
     return Status::OK();
  }
-  bool is_transpose_021;
-  Shape reduced_input_shape, reduced_output_shape;
-  std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
-      IsTranspose021(copy->operand(0)->shape(), copy->shape());
-  if (is_transpose_021 &&
-      reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
-      reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
-    thunk_sequence_->emplace_back(
-        BuildKernelThunk(copy, /*implements_whole_instruction=*/true));
-    VLOG(3) << "Emitting tiled 0-2-1 transposition";
-    constexpr int64 tile_size = 32;
-    constexpr int64 num_rows = 8;
-    int64 num_tiles = EmitTranspose021Tiled(
-        GetIrArray(*copy->operand(0), *copy)
-            .CastToShape(reduced_input_shape, &ir_builder_),
-        GetIrArray(*copy, *copy)
-            .CastToShape(reduced_output_shape, &ir_builder_),
-        tile_size, num_rows, &ir_builder_);
-    UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
-                           LastThunk(), ir_emitter_context_->llvm_module());
+  if (CheckAndEmitHloWithTile021(copy)) {
     return Status::OK();
   }
 
@@ -1082,9 +771,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
       tiled_input_shape, ir_emitter_context_->device_description());
 
   llvm::Type* index_ty = GetIndexTypeForKernel(
-      reduce,
-      launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
-      &ir_builder_);
+      reduce, launch_dimensions.launch_bound(), &ir_builder_);
 
   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
     return llvm::ConstantInt::get(index_ty, c);
@@ -1636,9 +1323,7 @@ Status IrEmitterUnnested::EmitRowReduction(
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       tiled_input_shape, ir_emitter_context_->device_description());
   llvm::Type* index_ty = GetIndexTypeForKernel(
-      reduce,
-      launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
-      &ir_builder_);
+      reduce, launch_dimensions.launch_bound(), &ir_builder_);
 
   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
     return llvm::ConstantInt::get(index_ty, c);
@@ -3004,5 +2689,502 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
       static_cast<KernelThunk*>(LastThunk()));
 }
 
+int IrEmitterUnnested::ConstructIrArrayForOutputs(
+    const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* output_arrays) {
+  int64 num_outputs = 1;
+  if (hlo.IsMultiOutputFusion()) {
+    num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+    output_arrays->reserve(num_outputs);
+    for (int64 i = 0; i < num_outputs; ++i) {
+      output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
+    }
+  } else {
+    output_arrays->push_back(GetIrArray(hlo, hlo));
+  }
+  return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructIrArrayForInputs(
+    const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* param_arrays) {
+  int64 num_params = hlo.operands().size();
+  param_arrays->reserve(num_params);
+  for (const HloInstruction* param : hlo.operands()) {
+    param_arrays->push_back(GetIrArray(*param, hlo));
+  }
+  return num_params;
+}
+
+int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+    const HloInstruction& hlo,
+    const std::vector<llvm_ir::IrArray>& output_arrays,
+    tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+    std::vector<Shape>* output_reduced_shapes,
+    std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays) {
+  int64 num_outputs = 1;
+  if (hlo.IsMultiOutputFusion()) {
+    num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+    output_in_reduced_shape_arrays->reserve(num_outputs);
+    output_reduced_shapes->reserve(num_outputs);
+    for (int64 i = 0; i < num_outputs; ++i) {
+      output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+          ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
+          reduced_output_dims));
+      output_in_reduced_shape_arrays->push_back(output_arrays[i].CastToShape(
+          (*output_reduced_shapes)[i], &ir_builder_));
+    }
+  } else {
+    output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+        hlo.shape().element_type(), reduced_output_dims));
+    output_in_reduced_shape_arrays->push_back(output_arrays[0].CastToShape(
+        (*output_reduced_shapes)[0], &ir_builder_));
+  }
+  return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
+    const HloInstruction& hlo,
+    const std::vector<llvm_ir::IrArray>& param_arrays,
+    const std::vector<llvm::Value*>& param_buffers,
+    tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+    std::vector<Shape>* param_reduced_shapes,
+    std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays) {
+  int64 num_params = hlo.operands().size();
+  param_in_reduced_shape_arrays->reserve(num_params);
+  param_reduced_shapes->reserve(num_params);
+  for (int64 id = 0; id < num_params; ++id) {
+    if (param_buffers[id] == nullptr) {
+      param_reduced_shapes->push_back(Shape());
+      param_in_reduced_shape_arrays->push_back(llvm_ir::IrArray());
+      continue;
+    }
+    const HloInstruction* param = hlo.operand(id);
+    param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+        param->shape().element_type(),
+        Permute({0, 2, 1}, reduced_output_dims)));
+    param_in_reduced_shape_arrays->push_back(param_arrays[id].CastToShape(
+        (*param_reduced_shapes)[id], &ir_builder_));
+  }
+  return num_params;
+}
+
+namespace {
+
+// Emits the intrinsic call and metadata for thread_id. Calculates the (y, x)
+// coordinate from thread_id and returns the (y, x) coordinate.
+std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
+    llvm::IRBuilder<>* builder, llvm::Value* tile_size,
+    int64 threads_per_tile) {
+  // Calculate the starting element coordinate within a tile for the current
+  // thread, (y, x) from thread_id.
+  llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+  llvm_ir::AddRangeMetadata(0, threads_per_tile,
+                            static_cast<llvm::Instruction*>(thread_id));
+  thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
+                                     /*isSigned=*/true, "thread.id.x");
+  auto x = builder->CreateURem(thread_id, tile_size);
+  auto y = builder->CreateUDiv(thread_id, tile_size);
+
+  return std::make_tuple(y, x);
+}
+
+// Emits the intrinsic call and metadata for block_id. Returns the block_id.
+llvm::Value* BuildBlockId(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+                          int64 num_tiles) {
+  llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+      llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
+  llvm_ir::AddRangeMetadata(0, num_tiles,
+                            static_cast<llvm::Instruction*>(block_id));
+  return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
+                                "block.id.x");
+}
+
+// Emits code to process up to (tile_size/num_rows) elements in a tile, given
+// `emit_ele_function` is the function to emit code to process one element, `y`
+// and `x` are the coordinates for the first element to process, and `index` is
+// the index for the origin of the tile. Emits bounds checks to ensure that each
+// processed element is within the boundary defined by `tile_width` and
+// `tile_height`.
+void EmitCodeWithBoundCheck(
+    int64 tile_size, int64 num_rows,
+    std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+        emit_ele_function,
+    const llvm_ir::IrArray::Index& index, const string& loop_name,
+    KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
+    llvm::Value* x, llvm::Value* tile_width, llvm::Value* tile_height) {
+  llvm::Type* index_ty = tile_width->getType();
+  // Emits a constant value with index type.
+  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(index_ty, c);
+  };
+  // Adds `addend` to the given `dim` of `index`.
+  auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
+                        int64 dim) {
+    index[dim] = builder->CreateAdd(index[dim], addend);
+    return index;
+  };
+  auto emit_not_last_row_true = [&]() {
+    for (int64 i = 0; i < tile_size; i += num_rows) {
+      auto source_idx = offset_dim(index, index_typed_const(i), /*dim=*/1);
+      auto y_loc = builder->CreateAdd(index_typed_const(i), y);
+      emit_ele_function(source_idx, y_loc);
+    }
+  };
+  auto emit_not_last_row_false = [&]() {
+    ksl->IfReturnVoid(
+        "x_in_tile", builder->CreateICmpULT(x, tile_width), [&]() {
+          // tile_height_upper_bound = ceil(tile_height / num_rows) *
+          // num_rows
+          auto tile_height_upper_bound = builder->CreateMul(
+              builder->CreateUDiv(
+                  builder->CreateAdd(tile_height,
+                                     index_typed_const(num_rows - 1)),
+                  index_typed_const(num_rows)),
+              index_typed_const(num_rows));
+
+          ksl->ForReturnVoid(
+              loop_name,
+              /*start=*/index_typed_const(0),
+              /*end=*/tile_height_upper_bound,
+              /*step=*/index_typed_const(num_rows), [&](llvm::Value* y_indvar) {
+                auto y_loc = builder->CreateAdd(y_indvar, y);
+                ksl->IfReturnVoid(
+                    "y_in_tile", builder->CreateICmpULT(y_loc, tile_height),
+                    [&]() {
+                      emit_ele_function(offset_dim(index, y_indvar, /*dim=*/1),
+                                        y_loc);
+                    });
+              });
+        });
+  };
+  ksl->IfReturnVoid(
+      "not_last_row",
+      builder->CreateAnd(
+          builder->CreateICmpEQ(index_typed_const(tile_size), tile_width),
+          builder->CreateICmpEQ(index_typed_const(tile_size), tile_height)),
+      emit_not_last_row_true, emit_not_last_row_false);
+}
+
+}  // namespace
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// which have a shape that is a 0-2-1 transpose of the output tensors.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape of
+// three components 0-1-2 in the order major to minor. The x- and y- dimensions
+// of the tensors are tiled in square tiles of edge length `tile_size`. Each
+// thread block of `tile_size` x `num_rows` threads transposes one tile: each
+// thread copies tile_size/num_rows elements from the input to a shared memory
+// tile, then the otherwise "regular hlo kernel" reads from the shared memory
+// instead of the original input.
+//
+// This is similar to the algorithm in CUDA:
+// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
+//
+// `tile_size` should usually be the same as warp size. We currently choose
+// 32 for `tile_size` and 4 for `num_rows`. The CUDA algorithm uses 8 for
+// `num_rows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more
+// efficient to launch fewer blocks so each transposes many tiles, and
+// in any case, the number of blocks we can launch is limited.
+//
+LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
+    HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+    tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+  // Parameters for the tiling algorithm.
+  constexpr int64 tile_size = 32;
+  constexpr int64 num_rows = 4;
+  constexpr int64 threads_per_tile = tile_size * num_rows;
+
+  // Construct the IrArray for the outputs.
+  std::vector<llvm_ir::IrArray> output_arrays;
+  int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
+
+  // Construct the IrArray for the inputs and initialize the array to store
+  // the tile buffers for the tiled parameters.
+  std::vector<llvm_ir::IrArray> param_arrays;
+  int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+  std::vector<llvm::Value*> param_buffers;
+  param_buffers.reserve(num_params);
+  for (int i = 0; i < num_params; ++i) {
+    param_buffers.push_back(nullptr);
+  }
+
+  // Allocate shared memory buffers to store the tiled parameters.
+  for (int64 id : tiled_param_ids) {
+    const HloInstruction* param = hlo->operand(id);
+    // Add 1 to the minor dimension to reduce shared memory bank conflicts.
+    llvm::Type* tile_type = llvm::ArrayType::get(
+        llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
+                                 param->shape().element_type(), module_),
+                             tile_size + 1),
+        tile_size);
+    auto* tile_base_ptr = new llvm::GlobalVariable(
+        *ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
+        /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+        llvm::UndefValue::get(tile_type), "tile", nullptr,
+        llvm::GlobalValue::NotThreadLocal,
+        /*AddressSpace=*/3 /* GPU shared memory */);
+    param_buffers[id] = tile_base_ptr;
+    VLOG(3) << "Add buffer for parameter " << id << " "
+            << llvm_ir::DumpToString(*tile_base_ptr);
+  }
+
+  // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
+  // for the purpose of tiling. Calculate the logical output dimensions in tiles
+  // from the reduced output dimensions.
+  std::vector<int64> output_dims_in_tiles = std::vector<int64>(
+      reduced_output_dims.begin(), reduced_output_dims.end());
+  CHECK_EQ(output_dims_in_tiles.size(), 3);
+  for (int i = 1; i < 3; ++i) {
+    output_dims_in_tiles[i] =
+        CeilOfRatio<int64>(output_dims_in_tiles[i], tile_size);
+  }
+  const int64 num_tiles =
+      std::accumulate(output_dims_in_tiles.begin(), output_dims_in_tiles.end(),
+                      1, std::multiplies<int64>());
+  LaunchDimensions launch_dimensions(num_tiles, threads_per_tile);
+
+  llvm::Type* index_ty = GetIndexTypeForKernel(
+      hlo, launch_dimensions.launch_bound(), &ir_builder_);
+  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+    return llvm::ConstantInt::get(index_ty, c);
+  };
+
+  // Cast each output IrArray to its corresponding reduced shape and keep the
+  // reduced shape live during IR emission.
+  std::vector<llvm_ir::IrArray> output_in_reduced_shape_arrays;
+  std::vector<Shape> output_reduced_shapes;
+  CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+               *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
+               &output_in_reduced_shape_arrays),
+           num_outputs);
+
+  // For each tiled parameter, cast its input IrArray to the corresponding
+  // reduced shape and keep the reduced shape live during IR emission.
+  std::vector<llvm_ir::IrArray> param_in_reduced_shape_arrays;
+  std::vector<Shape> param_reduced_shapes;
+  CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
+               *hlo, param_arrays, param_buffers, reduced_output_dims,
+               &param_reduced_shapes, &param_in_reduced_shape_arrays),
+           num_params);
+
+  // Calculate the starting element coordinate within a tile for the current
+  // thread, (y, x) from thread_id.
+  llvm::Value* x;
+  llvm::Value* y;
+  std::tie(y, x) = CalculateYXCoordinateWithinTile(
+      &ir_builder_, index_typed_const(tile_size), threads_per_tile);
+
+  // Calculate the index for the current output tile from block_id.
+  const llvm_ir::IrArray::Index output_tile_index(
+      BuildBlockId(&ir_builder_, index_ty, num_tiles),
+      ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
+                                               output_dims_in_tiles),
+      &ir_builder_);
+
+  // Output tile origin is the index for the first element of the current output
+  // tile.
+  const llvm_ir::IrArray::Index output_tile_origin = ({
+    llvm_ir::IrArray::Index index = output_tile_index;
+    for (int i = 1; i < 3; ++i) {
+      index[i] = ir_builder_.CreateMul(index[i], index_typed_const(tile_size),
+                                       "tile_origin." + std::to_string(i));
+    }
+    index;
+  });
+
+  // Calculate the input tile origin from the output tile origin.
+  const llvm_ir::IrArray::Index input_tile_origin(
+      Permute({0, 2, 1}, output_tile_origin.multidim()));
+
+  // A utility to add `addend` to the given `dim` of `index`.
+  auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
+                        int64 dim) {
+    index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+    return index;
+  };
+
+  // Calculate the current output tile bounds in each of the logical dimensions.
+  std::vector<llvm::Value*> output_tile_bounds(3);
+  for (int i = 1; i < 3; ++i) {
+    // Only last row or column may not have full size.
+    output_tile_bounds[i] = ir_builder_.CreateSelect(
+        ir_builder_.CreateICmpEQ(
+            output_tile_index[i],
+            index_typed_const(output_dims_in_tiles[i] - 1)),
+        index_typed_const(reduced_output_dims[i] -
+                          (output_dims_in_tiles[i] - 1) * tile_size),
+        index_typed_const(tile_size), "tile_size");
+  }
+
+  KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+  // A wrapper of EmitCodeWithBoundCheck, to pass in the parameters common to
+  // both call sites of the routine.
+  auto emit_code_with_bound_check =
+      [&](std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+              emit_ele_function,
+          const llvm_ir::IrArray::Index& index, const string& loop_name,
+          llvm::Value* tile_width, llvm::Value* tile_height) {
+        EmitCodeWithBoundCheck(tile_size, num_rows, emit_ele_function, index,
+                               loop_name, &ksl, &ir_builder_, y, x, tile_width,
+                               tile_height);
+      };
+
+  auto copy_tiled_params_to_buffers = [&](const llvm_ir::IrArray::Index& index,
+                                          llvm::Value* y_loc) {
+    for (int64 id : tiled_param_ids) {
+      llvm_ir::IrArray& input_in_logical_shape =
+          param_in_reduced_shape_arrays[id];
+      llvm::Value* buffer = param_buffers[id];
+      llvm::Instruction* store_to_buffer = ir_builder_.CreateStore(
+          input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+                                                      "input_element"),
+          ir_builder_.CreateGEP(buffer, {index_typed_const(0), y_loc, x}));
+      param_arrays[id].AnnotateBufferLoadStoreInstructionWithMetadata(
+          store_to_buffer);
+    }
+  };
+
+  const llvm_ir::IrArray::Index input_index =
+      offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+  // Copy input parameter values to shared memory buffers.
+  emit_code_with_bound_check(
+      // tile[y, x] = input[index]
+      copy_tiled_params_to_buffers, input_index, "input", output_tile_bounds[1],
+      output_tile_bounds[2]);
+
+  // Wait for all threads to reach this point, lest we copy a value from tile to
+  // output before the other thread copies it from input to tile.
+  // This is `__syncthreads` in CUDA.
+  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
+                               &ir_builder_);
+
+  llvm_ir::TiledParameterInfo tiled_param_info(param_buffers, y, x);
+  // A helper to emit an otherwise usual kernel for an hlo instruction except
+  // that values for the tiled parameters are read from tile buffers.
+  auto emit_kernel_using_buffers = [&](const llvm_ir::IrArray::Index& index,
+                                       llvm::Value* y_loc) {
+    if (hlo->opcode() == HloOpcode::kCopy) {
+      llvm::Instruction* load_from_buffer = ir_builder_.CreateLoad(
+          ir_builder_.CreateGEP(param_buffers[0],
+                                {ir_builder_.getInt64(0), x, y_loc}),
+          "output_element");
+      param_arrays[0].AnnotateBufferLoadStoreInstructionWithMetadata(
+          load_from_buffer);
+      output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+          index, load_from_buffer, &ir_builder_);
+    } else {
+      CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+      GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
+                                         &ir_builder_, GetNestedComputer());
+      FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+      tiled_param_info.set_y(y_loc);
+      fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+      TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+      llvm_ir::IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+          index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+          &ir_builder_);
+      const llvm_ir::ElementGenerator& output_generator =
+          fused_emitter.GetRootGenerator();
+      llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
+      if (hlo->IsMultiOutputFusion()) {
+        CHECK(output_value->getType()->isStructTy())
+            << "This BodyEmitter is for multi-output fusion, but target "
+               "element "
+               "generator does not produce values of struct type.";
+        CHECK_EQ(output_value->getType()->getStructNumElements(),
+                 output_in_reduced_shape_arrays.size());
+        for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+          output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+              index, ir_builder_.CreateExtractValue(output_value, i),
+              &ir_builder_);
+        }
+      } else {
+        output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+            index, output_value, &ir_builder_);
+      }
+    }
+  };
+
+  const llvm_ir::IrArray::Index output_index =
+      offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+  emit_code_with_bound_check(
+      // output[index] = ...
+      emit_kernel_using_buffers, output_index, "output", output_tile_bounds[2],
+      output_tile_bounds[1]);
+
+  // For multiple output fusion, emit a tuple with all the individual outputs.
+  if (hlo->IsMultiOutputFusion()) {
+    std::vector<llvm::Value*> tuple_operand_ptrs;
+    for (int64 i = 0; i < output_arrays.size(); ++i) {
+      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
+    }
+    llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &ir_builder_,
+                       module_);
+  }
+
+  return launch_dimensions;
+}
+
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+  HloOpcode opcode = hlo->opcode();
+  CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+  const Shape& output_shape = hlo->IsMultiOutputFusion()
+                                  ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+                                  : hlo->shape();
+  // If the output_shape is reduced to a 021 shape, find all the parameters of
+  // the hlo that are in the corresponding 012 shape.
+  std::vector<int64> params_012;
+  bool found = false;
+  std::vector<int64> reduced_dims_021;
+  int64 index = 0;
+  for (HloInstruction* operand : hlo->operands()) {
+    auto find_transpose_result =
+        llvm_ir::FindTranspose021(operand->shape(), output_shape);
+    if (find_transpose_result.has_value()) {
+      const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+      if (found) {
+        if (!ContainersEqual(reduced_dims_021, curr_reduced_dims_021)) {
+          // There is more than one possible transpose. Instead of picking one
+          // transpose, we simply give up here.
+          found = false;
+          break;
+        }
+      } else {
+        found = true;
+        reduced_dims_021 = curr_reduced_dims_021;
+      }
+      if (found) {
+        params_012.push_back(index);
+      }
+    }
+    ++index;
+  }
+
+  if (!found) {
+    return false;
+  }
+
+  if (reduced_dims_021[1] < kMinDimensionToTransposeTiled ||
+      reduced_dims_021[2] < kMinDimensionToTransposeTiled) {
+    return false;
+  }
+
+  VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
+  thunk_sequence_->emplace_back(
+      BuildKernelThunk(hlo,
+                       /*implements_whole_instruction=*/true));
+  const LaunchDimensions launch_dimensions =
+      EmitHlo021Tile(hlo, reduced_dims_021, params_012);
+  UpdateLaunchDimensions(launch_dimensions, LastThunk(),
+                         ir_emitter_context_->llvm_module());
+
+  return true;
+}
+
 }  // namespace gpu
 }  // namespace xla