aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc101
1 files changed, 67 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f4187cdade..958408e875 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -544,9 +544,9 @@ bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` threads transposes one
-// tile: each thread copies a row from the input to a shared memory tile, then
-// copies a column from the shared memory tile to the output.
+// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
+// transposes one tile: the threads cooperatively copy the tile from the input
+// to a shared memory tile, then copy it transposed from there to the output.
//
// `tile_size` should usually be same as warp size.
//
@@ -557,9 +557,10 @@ bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
// in any case, the number of blocks we can launch is limited.
//
// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/6172351b81af76d0b819fea6bb478cbd4016d6c2/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L183
+// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
- const int64 tile_size, llvm::IRBuilder<>* builder) {
+ const int64 tile_size, const int64 num_rows,
+ llvm::IRBuilder<>* builder) {
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [builder](llvm_ir::IrArray::Index index,
llvm::Value* addend, int64 dim) {
@@ -590,18 +591,29 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
// let x = threadIdx.x
llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, tile_size, static_cast<llvm::Instruction*>(x));
+ llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
+ static_cast<llvm::Instruction*>(x));
x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
"thread.id.x");
+ // computing logical thread ids
+ // logical_x = x % tile_size
+ auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
+
+ // logical_y = x / tile_size
+ auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
+
// `emit_cp` emits equivalent to following pseudocode:
// if (tile_size == tile_width && tile_size == tile_height) {
- // unroll for (y in 0..tile_size) {
- // emit_cp_element(index + {0, y, 0}, y);
+ // unroll for (i in range(0, tile_size, num_rows)) {
+ // emit_cp_element(index + {0, i, 0}, i + logical_y);
// }
// } else if (x < tile_width) {
- // for (y in 0..tile_height) {
- // emit_cp_element(index + {0, y, 0}, y);
+ // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
+ // for (i in range(0, tile_height_upperbound, num_rows)) {
+ // y_loc = i + logical_y;
+ // if (y_loc < tile_height)
+ // emit_cp_element(index + {0, i, 0}, y_loc);
// }
// }
//
@@ -615,32 +627,50 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
// tile, whether which is row or column is a function of whether we're copying
// from input or to output, and `index` is the index into the input or output
// array.
- auto emit_cp_tile = [builder, tile_size, x, &offset_dim](
- std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
- emit_cp_element,
- llvm::Value* tile_width, llvm::Value* tile_height,
- const llvm_ir::IrArray::Index& index, const string& loop_name) {
+ auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
+ logical_y](
+ std::function<void(const llvm_ir::IrArray::Index&,
+ llvm::Value*)>
+ emit_cp_element,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const llvm_ir::IrArray::Index& index,
+ const string& loop_name) {
llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
builder->CreateAnd(
builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
"not_last_row", builder);
builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
- for (int64 i = 0; i < tile_size; ++i) {
- emit_cp_element(offset_dim(index, builder->getInt64(i), /*dim=*/1),
- builder->getInt64(i));
+ for (int64 i = 0; i < tile_size; i += num_rows) {
+ auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
+ emit_cp_element(source_idx, y_loc);
}
builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(x, tile_width), "in_tile", builder);
+ builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
- auto loop = llvm_ir::ForLoop::EmitForLoop(loop_name, builder->getInt64(0),
- tile_height, builder->getInt64(1),
- builder);
+
+ // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
+ builder->getInt64(num_rows)),
+ builder->getInt64(num_rows));
+
+ auto loop = llvm_ir::ForLoop::EmitForLoop(
+ loop_name, builder->getInt64(0), tile_height_upper_bound,
+ builder->getInt64(num_rows), builder);
llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
+
+ auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
+ auto if_y_in_tile = llvm_ir::EmitIfThenElse(
+ builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
+ builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
+
emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
- loop->GetIndVarValue());
+ y_loc);
builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
};
@@ -673,7 +703,8 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
index;
});
const llvm_ir::IrArray::Index input_index =
- offset_dim(input_tile_origin, x, /*dim=*/2);
+ offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
+ /*dim=*/1);
std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
// Only last row or column may not have full size.
for (int i = 1; i < 3; ++i) {
@@ -688,11 +719,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
// Load data from input memory to shared memory tile.
emit_cp_tile(
// tile[y, x] = input_array[index]
- [builder, tile, x, &input](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
+ [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y) {
builder->CreateStore(
input.EmitReadArrayElement(index, builder, "input_element"),
- builder->CreateGEP(tile, {builder->getInt64(0), y, x}));
+ builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
},
tile_dims[2], tile_dims[1], input_index, "input");
@@ -706,17 +737,18 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
const llvm_ir::IrArray::Index output_tile_origin(
Permute({0, 2, 1}, input_tile_origin.multidim()));
const llvm_ir::IrArray::Index output_index =
- offset_dim(output_tile_origin, x, /*dim=*/2);
+ offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
+ logical_y, /*dim=*/1);
// Store data from shared memory tile to output memory.
emit_cp_tile(
// output_array[index] = tile[x, y]
- [builder, tile, x, &output](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
+ [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y) {
output.EmitWriteArrayElement(
index,
builder->CreateLoad(
- builder->CreateGEP(tile, {builder->getInt64(0), x, y}),
+ builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
"output_element"),
builder);
},
@@ -742,13 +774,14 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
thunk_sequence_->emplace_back(BuildKernelThunk(copy));
VLOG(3) << "Emitting tiled 0-2-1 transposition";
constexpr int64 tile_size = 32;
+ constexpr int64 num_rows = 8;
int64 num_tiles = EmitTranspose021Tiled(
GetIrArray(*(copy->operand(0)))
.CastToShape(reduced_input_shape, &ir_builder_),
GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_),
- tile_size, &ir_builder_);
- UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(),
- ir_emitter_context_->llvm_module());
+ tile_size, num_rows, &ir_builder_);
+ UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
+ LastThunk(), ir_emitter_context_->llvm_module());
return Status::OK();
}