author     Bixia Zheng <bixia@google.com>  2018-07-04 15:55:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-04 15:58:20 -0700
commit     eb8a110e52a8e8ed7b9db48e01115d3110e918c2 (patch)
tree       ff984481a05efb9efd0c78c0c7065b71d268c8e6 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent     548601aee34f3eb0e6857b5ce6df25db1b50a4b4 (diff)
[XLA:GPU] Enhance the tiled 0-2-1 transpose algorithm to handle fusion.
Add class TiledParameterInfo to provide the information FusedIrEmitter needs to read the content of a tiled parameter from the tile buffer instead of from the original input memory.

Reimplement the tiled 0-2-1 transpose algorithm, which previously handled only copy instructions, in a more general way so that it can handle both fusion instructions and copy instructions.

The original tiled 0-2-1 transpose implementation incorrectly used (tile_size+1) rows for the tile buffer to reduce shared memory bank conflicts, when it should have used (tile_size+1) columns instead. This is a performance issue and is fixed in the new implementation.

The original tiled 0-2-1 transpose implementation also did not generate LLVM alias metadata for the loads and stores of the tensors. This was due to a bug where IrArray::CastToShape failed to copy metadata to the new IrArray object. This is also a performance issue and is fixed in this change.

Modify KernelSupportLibrary to support emitting an if-stmt with a given branch name prefix.

Add test cases for the new implementation.

PiperOrigin-RevId: 203310403
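As a point of reference for the bank-conflict fix described above, here is a minimal CUDA sketch of the tile layout the new emitter produces; the kernel name, the float element type, and the standalone-kernel form are illustrative assumptions, not code from this change. The shared-memory tile is padded by one element on its minor (column) dimension so that the column-wise reads in the write-out phase hit distinct banks, and the two phases are separated by __syncthreads, matching the nvvm_barrier0 call the emitter inserts.

// Illustrative sketch only (not part of this change): a plain CUDA kernel with
// the same tile layout as the generated code. Transposes in[d0][d1][d2] into
// out[d0][d2][d1]. Launch with dim0 * ceil(dim1/32) * ceil(dim2/32) blocks of
// dim3(kTileSize, kNumRows) threads.
constexpr int kTileSize = 32;  // tile_size in the emitter
constexpr int kNumRows = 4;    // num_rows in the emitter

__global__ void Transpose021Tiled(const float* in, float* out,
                                  long dim0, long dim1, long dim2) {
  // Pad the minor (column) dimension by one element: threads reading a column
  // of the tile then hit distinct shared memory banks.
  __shared__ float tile[kTileSize][kTileSize + 1];

  long tiles_d1 = (dim1 + kTileSize - 1) / kTileSize;
  long tiles_d2 = (dim2 + kTileSize - 1) / kTileSize;
  long d0 = blockIdx.x / (tiles_d1 * tiles_d2);
  long plane = blockIdx.x % (tiles_d1 * tiles_d2);
  long tile_d1 = (plane / tiles_d2) * kTileSize;  // tile origin along dim 1
  long tile_d2 = (plane % tiles_d2) * kTileSize;  // tile origin along dim 2

  int x = threadIdx.x;
  // Each thread copies kTileSize / kNumRows elements of one tile column.
  for (int y = threadIdx.y; y < kTileSize; y += kNumRows) {
    long i1 = tile_d1 + y, i2 = tile_d2 + x;
    if (i1 < dim1 && i2 < dim2) {
      tile[y][x] = in[(d0 * dim1 + i1) * dim2 + i2];
    }
  }
  __syncthreads();  // corresponds to the emitted nvvm_barrier0
  for (int y = threadIdx.y; y < kTileSize; y += kNumRows) {
    long o1 = tile_d2 + y, o2 = tile_d1 + x;  // output index, 0-2-1 transposed
    if (o1 < dim2 && o2 < dim1) {
      out[(d0 * dim2 + o1) * dim1 + o2] = tile[x][y];  // column read: padded
    }
  }
}

With the pad on the rows instead, i.e. tile[kTileSize + 1][kTileSize], consecutive entries of a tile column sit exactly 32 words apart and therefore in the same bank, which is the performance problem this change fixes.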
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  826
1 file changed, 504 insertions, 322 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 945ad371e8..186672047b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -697,6 +697,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return Status::OK();
}
+ if (CheckAndEmitHloWithTile021(fusion)) {
+ return Status::OK();
+ }
+
CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
int unroll_factor = ComputeMaxUnrollFactor(fusion);
@@ -705,302 +709,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return IrEmitter::HandleFusion(fusion);
}
-namespace {
-
-// Returns the indices of the first elements of all consecutive subarrays of the
-// given array. For example:
-// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
- std::vector<size_t> is = {0};
- for (size_t i = 1; i < xs.size(); ++i) {
- if (1 != xs[i] - xs[i - 1]) {
- is.push_back(i);
- }
- }
- return is;
-}
-
-// Merges the sequences of dimensions of the given shape which start at the
-// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
- std::vector<int64> dimensions;
- for (size_t i = 1; i <= segs.size(); ++i) {
- dimensions.push_back(std::accumulate(
- shape.dimensions().begin() + segs[i - 1],
- shape.dimensions().begin() +
- (segs.size() == i ? shape.dimensions().size() : segs[i]),
- 1, std::multiplies<int64>()));
- }
- return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
- dimensions);
-}
-
-// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
-// if so, the normalized and rank-reduced shapes. The shapes must have the same
-// dimensions, so this considers layout only.
-//
-// This function recognizes higher-rank transposes which are elementwise
-// equivalent to a 0-2-1 transpose.
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
- CHECK(ShapeUtil::Compatible(a, b));
- std::vector<int64> perm(a.dimensions().size());
- {
- auto layout_a_orig = LayoutUtil::MinorToMajor(a);
- std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
- auto layout_b_orig = LayoutUtil::MinorToMajor(b);
- std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
- for (size_t i = 0; i < perm.size(); ++i) {
- perm[i] = PositionInContainer(layout_b, layout_a[i]);
- }
- }
- auto segs = ConsecutiveSegments(perm);
- Shape norm_a =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
- Shape norm_b =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
- if (3 == segs.size() && 0 == perm[0]) {
- Shape reduced_a = MergeDimensions(segs, norm_a);
- Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
- b.element_type(),
- Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
- return std::make_tuple(true, reduced_a, reduced_b);
- }
- return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
-}
-
-// Returns whether the given shapes are potentially of a 0-2-1 transpose.
-// As 0-2-1 is a self-inverse permutation, which shape is input or output is
-// arbitrary.
-bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
- return 3 == b.dimensions().size() &&
- ShapeUtil::Compatible(
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
- ShapeUtil::PermuteDimensions(
- {0, 2, 1},
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- b)));
-}
-
-// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
-// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
-// transposes one tile: each thread copies a row from the input to a shared
-// memory tile, then copies a column from the shared memory tile to the output.
-//
-// `tile_size` should usually be same as warp size.
-//
-// Returns (number of tiles = number of thread blocks needed).
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
-// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
-int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
- const int64 tile_size, const int64 num_rows,
- llvm::IRBuilder<>* builder) {
- // Adds `addend` to the given `dim` of `index`.
- auto offset_dim = [builder](llvm_ir::IrArray::Index index,
- llvm::Value* addend, int64 dim) {
- index[dim] = builder->CreateAdd(index[dim], addend);
- return index;
- };
-
- CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
-
- Shape input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input.GetShape());
- Shape output_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- output.GetShape());
- input = input.CastToShape(input_shape, builder);
- output = output.CastToShape(output_shape, builder);
-
- llvm::Type* tile_type = llvm::ArrayType::get(
- llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
- // One extra here to avoid share memory bank conflict
- tile_size + 1);
- auto* tile = new llvm::GlobalVariable(
- *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
- /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
- llvm::UndefValue::get(tile_type), "tile", nullptr,
- llvm::GlobalValue::NotThreadLocal,
- /*AddressSpace=*/3 /* GPU shared memory */);
-
- // let x = threadIdx.x
- llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
- static_cast<llvm::Instruction*>(x));
- x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
- "thread.id.x");
-
- // computing logical thread ids
- // logical_x = x % tile_size
- auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
-
- // logical_y = x / tile_size
- auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
-
- // `emit_cp` emits equivalent to following pseudocode:
- // if (tile_size == tile_width && tile_size == tile_height) {
- // unroll for (i in range(0, tile_size, num_rows)) {
- // emit_cp_element(index + {0, i, 0}, y + logical_y);
- // }
- // } else if (x < tile_width) {
- // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
- // for (i in range(0, tile_height_upperbound, num_rows)) {
- // y_loc = i + logical_y;
- // if (y_loc < tile_height)
- // emit_cp_element(index + {0, i, 0}, y_loc);
- // }
- // }
- //
- // We use this to emit both the copy from input to tile and the copy from tile
- // to output.
- //
- // `index` is the origin of the row or column in the input or output array.
- //
- // `emit_cp_element(index, y)` emits code to copy a single element between the
- // tile and the input or output array, where `y` is the `y`-position in the
- // tile, whether which is row or column is a function of whether we're copying
- // from input or to output, and `index` is the index into the input or output
- // array.
- auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
- logical_y](
- std::function<void(const llvm_ir::IrArray::Index&,
- llvm::Value*)>
- emit_cp_element,
- llvm::Value* tile_width, llvm::Value* tile_height,
- const llvm_ir::IrArray::Index& index,
- const string& loop_name) {
- llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
- builder->CreateAnd(
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
- "not_last_row", builder);
- builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
- for (int64 i = 0; i < tile_size; i += num_rows) {
- auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
- auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
- emit_cp_element(source_idx, y_loc);
- }
- builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
- llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
- builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
-
- // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
- auto tile_height_upper_bound = builder->CreateMul(
- builder->CreateUDiv(
- builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
- builder->getInt64(num_rows)),
- builder->getInt64(num_rows));
-
- auto loop = llvm_ir::ForLoop::EmitForLoop(
- loop_name, builder->getInt64(0), tile_height_upper_bound,
- builder->getInt64(num_rows), builder);
- llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
- builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
-
- auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
- auto if_y_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
- builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
-
- emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
- y_loc);
- builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
- };
-
- auto input_dims_in_tiles = input_shape.dimensions();
- // Unpermuted dimensions are untiled.
- for (int i = 1; i < 3; ++i) {
- input_dims_in_tiles[i] =
- CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
- }
- int64 num_tiles =
- std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
- std::multiplies<int64>());
- const llvm_ir::IrArray::Index input_tile_index(
- /*linear=*/builder->CreateIntCast(
- llvm_ir::AddRangeMetadata(
- 0, num_tiles,
- static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
- builder))),
- builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
- ShapeUtil::MakeShapeWithDescendingLayout(
- PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
- builder);
- const llvm_ir::IrArray::Index input_tile_origin = ({
- llvm_ir::IrArray::Index index = input_tile_index;
- for (int i = 1; i < 3; ++i) {
- index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
- "tile_origin." + std::to_string(i));
- }
- index;
- });
- const llvm_ir::IrArray::Index input_index =
- offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
- /*dim=*/1);
- std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
- // Only last row or column may not have full size.
- for (int i = 1; i < 3; ++i) {
- tile_dims[i] = builder->CreateSelect(
- builder->CreateICmpEQ(input_tile_index[i],
- builder->getInt64(input_dims_in_tiles[i] - 1)),
- builder->getInt64(input_shape.dimensions(i) -
- (input_dims_in_tiles[i] - 1) * tile_size),
- builder->getInt64(tile_size), "tile_size");
- }
-
- // Load data from input memory to shared memory tile.
- emit_cp_tile(
- // tile[y, x] = input_array[index]
- [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- builder->CreateStore(
- input.EmitReadArrayElement(index, builder, "input_element"),
- builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
- },
- tile_dims[2], tile_dims[1], input_index, "input");
-
- // Wait for all threads to reach this point, lest we copy a value from tile to
- // output before the other thread copies it from input to tile.
- // This is `__syncthreads` in CUDA.
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
-
- const llvm_ir::IrArray::Index output_tile_index(
- Permute({0, 2, 1}, input_tile_index.multidim()));
- const llvm_ir::IrArray::Index output_tile_origin(
- Permute({0, 2, 1}, input_tile_origin.multidim()));
- const llvm_ir::IrArray::Index output_index =
- offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
- logical_y, /*dim=*/1);
-
- // Store data from shared memory tile to output memory.
- emit_cp_tile(
- // output_array[index] = tile[x, y]
- [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- output.EmitWriteArrayElement(
- index,
- builder->CreateLoad(
- builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
- "output_element"),
- builder);
- },
- tile_dims[1], tile_dims[2], output_index, "output");
-
- return num_tiles;
-}
-
-} // namespace
-
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
*copy)) {
@@ -1012,26 +720,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}
- bool is_transpose_021;
- Shape reduced_input_shape, reduced_output_shape;
- std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
- IsTranspose021(copy->operand(0)->shape(), copy->shape());
- if (is_transpose_021 &&
- reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
- reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
- thunk_sequence_->emplace_back(
- BuildKernelThunk(copy, /*implements_whole_instruction=*/true));
- VLOG(3) << "Emitting tiled 0-2-1 transposition";
- constexpr int64 tile_size = 32;
- constexpr int64 num_rows = 8;
- int64 num_tiles = EmitTranspose021Tiled(
- GetIrArray(*copy->operand(0), *copy)
- .CastToShape(reduced_input_shape, &ir_builder_),
- GetIrArray(*copy, *copy)
- .CastToShape(reduced_output_shape, &ir_builder_),
- tile_size, num_rows, &ir_builder_);
- UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
- LastThunk(), ir_emitter_context_->llvm_module());
+ if (CheckAndEmitHloWithTile021(copy)) {
return Status::OK();
}
@@ -1082,9 +771,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
tiled_input_shape, ir_emitter_context_->device_description());
llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
@@ -1636,9 +1323,7 @@ Status IrEmitterUnnested::EmitRowReduction(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
@@ -3004,5 +2689,502 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
+int IrEmitterUnnested::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* output_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays->push_back(GetIrArray(hlo, hlo));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo, std::vector<llvm_ir::IrArray>* param_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_arrays->reserve(num_params);
+ for (const HloInstruction* param : hlo.operands()) {
+ param_arrays->push_back(GetIrArray(*param, hlo));
+ }
+ return num_params;
+}
+
+int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_in_reduced_shape_arrays->reserve(num_outputs);
+ output_reduced_shapes->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
+ reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[i].CastToShape(
+ (*output_reduced_shapes)[i], &ir_builder_));
+ }
+ } else {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ hlo.shape().element_type(), reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[0].CastToShape(
+ (*output_reduced_shapes)[0], &ir_builder_));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_in_reduced_shape_arrays->reserve(num_params);
+ param_reduced_shapes->reserve(num_params);
+ for (int64 id = 0; id < num_params; ++id) {
+ if (param_buffers[id] == nullptr) {
+ param_reduced_shapes->push_back(Shape());
+ param_in_reduced_shape_arrays->push_back(llvm_ir::IrArray());
+ continue;
+ }
+ const HloInstruction* param = hlo.operand(id);
+ param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ param->shape().element_type(),
+ Permute({0, 2, 1}, reduced_output_dims)));
+ param_in_reduced_shape_arrays->push_back(param_arrays[id].CastToShape(
+ (*param_reduced_shapes)[id], &ir_builder_));
+ }
+ return num_params;
+}
+
+namespace {
+
+// Emits the intrinsic call and metadata for thread_id, then calculates and
+// returns the (y, x) coordinate within the tile from thread_id.
+std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
+ llvm::IRBuilder<>* builder, llvm::Value* tile_size,
+ int64 threads_per_tile) {
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, threads_per_tile,
+ static_cast<llvm::Instruction*>(thread_id));
+ thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
+ /*isSigned=*/true, "thread.id.x");
+ auto x = builder->CreateURem(thread_id, tile_size);
+ auto y = builder->CreateUDiv(thread_id, tile_size);
+
+ return std::make_tuple(y, x);
+}
+
+// Emits the intrinsic call and metadata for block_id. Returns the block_id.
+llvm::Value* BuildBlockId(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+ int64 num_tiles) {
+ llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, num_tiles,
+ static_cast<llvm::Instruction*>(block_id));
+ return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
+ "block.id.x");
+}
+
+// Emits code to process up to (tile_size/num_rows) elements in a tile, where
+// `emit_ele_function` is the function to emit code to process one element, `y`
+// and `x` are the coordinates of the first element to process, and `index` is
+// the index for the origin of the tile. Emits bounds checks to ensure that each
+// processed element is within the boundary defined by `tile_width` and
+// `tile_height`.
+void EmitCodeWithBoundCheck(
+ int64 tile_size, int64 num_rows,
+ std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+ emit_ele_function,
+ const llvm_ir::IrArray::Index& index, const string& loop_name,
+ KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
+ llvm::Value* x, llvm::Value* tile_width, llvm::Value* tile_height) {
+ llvm::Type* index_ty = tile_width->getType();
+ // Emits a constant value with index type.
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
+ int64 dim) {
+ index[dim] = builder->CreateAdd(index[dim], addend);
+ return index;
+ };
+ auto emit_not_last_row_true = [&]() {
+ for (int64 i = 0; i < tile_size; i += num_rows) {
+ auto source_idx = offset_dim(index, index_typed_const(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(index_typed_const(i), y);
+ emit_ele_function(source_idx, y_loc);
+ }
+ };
+ auto emit_not_last_row_false = [&]() {
+ ksl->IfReturnVoid(
+ "x_in_tile", builder->CreateICmpULT(x, tile_width), [&]() {
+ // tile_height_upper_bound = ceil(tile_height / num_rows) *
+ // num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height,
+ index_typed_const(num_rows - 1)),
+ index_typed_const(num_rows)),
+ index_typed_const(num_rows));
+
+ ksl->ForReturnVoid(
+ loop_name,
+ /*start=*/index_typed_const(0),
+ /*end=*/tile_height_upper_bound,
+ /*step=*/index_typed_const(num_rows), [&](llvm::Value* y_indvar) {
+ auto y_loc = builder->CreateAdd(y_indvar, y);
+ ksl->IfReturnVoid(
+ "y_in_tile", builder->CreateICmpULT(y_loc, tile_height),
+ [&]() {
+ emit_ele_function(offset_dim(index, y_indvar, /*dim=*/1),
+ y_loc);
+ });
+ });
+ });
+ };
+ ksl->IfReturnVoid(
+ "not_last_row",
+ builder->CreateAnd(
+ builder->CreateICmpEQ(index_typed_const(tile_size), tile_width),
+ builder->CreateICmpEQ(index_typed_const(tile_size), tile_height)),
+ emit_not_last_row_true, emit_not_last_row_false);
+}
+
+} // namespace
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// which have a shape that is a 0-2-1 transpose of the output tensors.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape of
+// three components 0-1-2 in the order major to minor. The x- and y- dimensions
+// of the tensors are tiled in square tiles of edge length `tile_size`. Each
+// thread block of `tile_size` x `num_rows` threads transposes one tile: each
+// thread copies tile_size/num_rows elements from the input to a shared memory
+// tile, then the otherwise "regular hlo kernel" reads from the shared memory
+// instead of the original input.
+//
+// This is similar to the algorithm in CUDA:
+// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
+//
+// `tile_size` should usually be the same as the warp size. We currently choose
+// 32 for `tile_size` and 4 for `num_rows`. The CUDA algorithm uses 8 for
+// `num_rows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
+// to launch fewer blocks so each transposes many tiles, and
+// in any case, the number of blocks we can launch is limited.
+//
+LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
+ HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ // Parameters for the tiling algorithm.
+ constexpr int64 tile_size = 32;
+ constexpr int64 num_rows = 4;
+ constexpr int64 threads_per_tile = tile_size * num_rows;
+
+ // Construct the IrArray for the outputs.
+ std::vector<llvm_ir::IrArray> output_arrays;
+ int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
+
+ // Construct the IrArray for the inputs and initialize the array to store
+ // the tile buffers for the tiled parameters.
+ std::vector<llvm_ir::IrArray> param_arrays;
+ int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+ std::vector<llvm::Value*> param_buffers;
+ param_buffers.reserve(num_params);
+ for (int i = 0; i < num_params; ++i) {
+ param_buffers.push_back(nullptr);
+ }
+
+ // Allocate shared memory buffers to store the tiled parameters.
+ for (int64 id : tiled_param_ids) {
+ const HloInstruction* param = hlo->operand(id);
+ // Add 1 to the minor dimension to reduce shared memory bank conflicts.
+ llvm::Type* tile_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
+ param->shape().element_type(), module_),
+ tile_size + 1),
+ tile_size);
+ auto* tile_base_ptr = new llvm::GlobalVariable(
+ *ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
+ /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+ llvm::UndefValue::get(tile_type), "tile", nullptr,
+ llvm::GlobalValue::NotThreadLocal,
+ /*AddressSpace=*/3 /* GPU shared memory */);
+ param_buffers[id] = tile_base_ptr;
+ VLOG(3) << "Add buffer for parameter " << id << " "
+ << llvm_ir::DumpToString(*tile_base_ptr);
+ }
+
+ // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
+ // for the purpose of tiling. Calculate the logical output dimensions in tiles
+ // from the reduced output dimensions.
+ std::vector<int64> output_dims_in_tiles = std::vector<int64>(
+ reduced_output_dims.begin(), reduced_output_dims.end());
+ CHECK_EQ(output_dims_in_tiles.size(), 3);
+ for (int i = 1; i < 3; ++i) {
+ output_dims_in_tiles[i] =
+ CeilOfRatio<int64>(output_dims_in_tiles[i], tile_size);
+ }
+ const int64 num_tiles =
+ std::accumulate(output_dims_in_tiles.begin(), output_dims_in_tiles.end(),
+ 1, std::multiplies<int64>());
+ LaunchDimensions launch_dimensions(num_tiles, threads_per_tile);
+
+ llvm::Type* index_ty = GetIndexTypeForKernel(
+ hlo, launch_dimensions.launch_bound(), &ir_builder_);
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ // Cast each output IrArray to its corresponding reduced shape and keep the
+ // reduced shape live during IR emission.
+ std::vector<llvm_ir::IrArray> output_in_reduced_shape_arrays;
+ std::vector<Shape> output_reduced_shapes;
+ CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
+ &output_in_reduced_shape_arrays),
+ num_outputs);
+
+ // For each tiled parameter, cast its input IrArray to the corresponding
+ // reduced shape and keep the reduced shape live during IR emission.
+ std::vector<llvm_ir::IrArray> param_in_reduced_shape_arrays;
+ std::vector<Shape> param_reduced_shapes;
+ CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ *hlo, param_arrays, param_buffers, reduced_output_dims,
+ &param_reduced_shapes, &param_in_reduced_shape_arrays),
+ num_params);
+
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* x;
+ llvm::Value* y;
+ std::tie(y, x) = CalculateYXCoordinateWithinTile(
+ &ir_builder_, index_typed_const(tile_size), threads_per_tile);
+
+ // Calculate the index for the current output tile from block_id.
+ const llvm_ir::IrArray::Index output_tile_index(
+ BuildBlockId(&ir_builder_, index_ty, num_tiles),
+ ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
+ output_dims_in_tiles),
+ &ir_builder_);
+
+ // Output tile origin is the index for the first element of the current output
+ // tile.
+ const llvm_ir::IrArray::Index output_tile_origin = ({
+ llvm_ir::IrArray::Index index = output_tile_index;
+ for (int i = 1; i < 3; ++i) {
+ index[i] = ir_builder_.CreateMul(index[i], index_typed_const(tile_size),
+ "tile_origin." + std::to_string(i));
+ }
+ index;
+ });
+
+ // Calculate the input tile origin from the output tile origin.
+ const llvm_ir::IrArray::Index input_tile_origin(
+ Permute({0, 2, 1}, output_tile_origin.multidim()));
+
+ // A utility to add `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](llvm_ir::IrArray::Index index, llvm::Value* addend,
+ int64 dim) {
+ index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+ return index;
+ };
+
+ // Calculate the current output tile bounds in each of the logical dimensions.
+ std::vector<llvm::Value*> output_tile_bounds(3);
+ for (int i = 1; i < 3; ++i) {
+ // Only last row or column may not have full size.
+ output_tile_bounds[i] = ir_builder_.CreateSelect(
+ ir_builder_.CreateICmpEQ(
+ output_tile_index[i],
+ index_typed_const(output_dims_in_tiles[i] - 1)),
+ index_typed_const(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * tile_size),
+ index_typed_const(tile_size), "tile_size");
+ }
+
+ KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+ // A wrapper of EmitCodeWithBoundCheck that passes in the parameters common to
+ // both call sites of the routine.
+ auto emit_code_with_bound_check =
+ [&](std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+ emit_ele_function,
+ const llvm_ir::IrArray::Index& index, const string& loop_name,
+ llvm::Value* tile_width, llvm::Value* tile_height) {
+ EmitCodeWithBoundCheck(tile_size, num_rows, emit_ele_function, index,
+ loop_name, &ksl, &ir_builder_, y, x, tile_width,
+ tile_height);
+ };
+
+ auto copy_tiled_params_to_buffers = [&](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y_loc) {
+ for (int64 id : tiled_param_ids) {
+ llvm_ir::IrArray& input_in_logical_shape =
+ param_in_reduced_shape_arrays[id];
+ llvm::Value* buffer = param_buffers[id];
+ llvm::Instruction* store_to_buffer = ir_builder_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+ "input_element"),
+ ir_builder_.CreateGEP(buffer, {index_typed_const(0), y_loc, x}));
+ param_arrays[id].AnnotateBufferLoadStoreInstructionWithMetadata(
+ store_to_buffer);
+ }
+ };
+
+ const llvm_ir::IrArray::Index input_index =
+ offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+ // Copy input parameter values to shared memory buffers.
+ emit_code_with_bound_check(
+ // tile[y, x] = input[index]
+ copy_tiled_params_to_buffers, input_index, "input", output_tile_bounds[1],
+ output_tile_bounds[2]);
+
+ // Wait for all threads to reach this point, lest we copy a value from tile to
+ // output before the other thread copies it from input to tile.
+ // This is `__syncthreads` in CUDA.
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
+ &ir_builder_);
+
+ llvm_ir::TiledParameterInfo tiled_param_info(param_buffers, y, x);
+ // A helper to emit the usual kernel for an hlo instruction, except that
+ // values for the tiled parameters are read from the tile buffers.
+ auto emit_kernel_using_buffers = [&](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y_loc) {
+ if (hlo->opcode() == HloOpcode::kCopy) {
+ llvm::Instruction* load_from_buffer = ir_builder_.CreateLoad(
+ ir_builder_.CreateGEP(param_buffers[0],
+ {ir_builder_.getInt64(0), x, y_loc}),
+ "output_element");
+ param_arrays[0].AnnotateBufferLoadStoreInstructionWithMetadata(
+ load_from_buffer);
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, load_from_buffer, &ir_builder_);
+ } else {
+ CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+ tiled_param_info.set_y(y_loc);
+ fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+ TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+ llvm_ir::IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+ index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+ &ir_builder_);
+ const llvm_ir::ElementGenerator& output_generator =
+ fused_emitter.GetRootGenerator();
+ llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
+ if (hlo->IsMultiOutputFusion()) {
+ CHECK(output_value->getType()->isStructTy())
+ << "This BodyEmitter is for multi-output fusion, but target "
+ "element "
+ "generator does not produce values of struct type.";
+ CHECK_EQ(output_value->getType()->getStructNumElements(),
+ output_in_reduced_shape_arrays.size());
+ for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+ output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+ index, ir_builder_.CreateExtractValue(output_value, i),
+ &ir_builder_);
+ }
+ } else {
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, output_value, &ir_builder_);
+ }
+ }
+ };
+
+ const llvm_ir::IrArray::Index output_index =
+ offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+ emit_code_with_bound_check(
+ // output[index] = ...
+ emit_kernel_using_buffers, output_index, "output", output_tile_bounds[2],
+ output_tile_bounds[1]);
+
+ // For multiple output fusion, emit a tuple with all the individual outputs.
+ if (hlo->IsMultiOutputFusion()) {
+ std::vector<llvm::Value*> tuple_operand_ptrs;
+ for (int64 i = 0; i < output_arrays.size(); ++i) {
+ tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
+ }
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &ir_builder_,
+ module_);
+ }
+
+ return launch_dimensions;
+}
+
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+ HloOpcode opcode = hlo->opcode();
+ CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+ const Shape& output_shape = hlo->IsMultiOutputFusion()
+ ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+ : hlo->shape();
+ // If the output_shape can be reduced to a 0-2-1 shape, find all the parameters
+ // of the hlo that are in the corresponding 0-1-2 shape.
+ std::vector<int64> params_012;
+ bool found = false;
+ std::vector<int64> reduced_dims_021;
+ int64 index = 0;
+ for (HloInstruction* operand : hlo->operands()) {
+ auto find_transpose_result =
+ llvm_ir::FindTranspose021(operand->shape(), output_shape);
+ if (find_transpose_result.has_value()) {
+ const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+ if (found) {
+ if (!ContainersEqual(reduced_dims_021, curr_reduced_dims_021)) {
+ // There is more than one possible transpose. Instead of picking one
+ // transpose, we simply give up here.
+ found = false;
+ break;
+ }
+ } else {
+ found = true;
+ reduced_dims_021 = curr_reduced_dims_021;
+ }
+ if (found) {
+ params_012.push_back(index);
+ }
+ }
+ ++index;
+ }
+
+ if (!found) {
+ return false;
+ }
+
+ if (reduced_dims_021[1] < kMinDimensionToTransposeTiled ||
+ reduced_dims_021[2] < kMinDimensionToTransposeTiled) {
+ return false;
+ }
+
+ VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(hlo,
+ /*implements_whole_instruction=*/true));
+ const LaunchDimensions launch_dimensions =
+ EmitHlo021Tile(hlo, reduced_dims_021, params_012);
+ UpdateLaunchDimensions(launch_dimensions, LastThunk(),
+ ir_emitter_context_->llvm_module());
+
+ return true;
+}
+
} // namespace gpu
} // namespace xla