diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 101 |
1 files changed, 67 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f4187cdade..958408e875 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -544,9 +544,9 @@ bool AreShapesForTranspose021(const Shape& a, const Shape& b) { // Emits a tiled 0-2-1 transpose, assuming both input and output lain out from // major to minor. The x- and y- dimensions are tiled in square tiles of edge -// length `tile_size`. Each thread block of `tile_size` threads transposes one -// tile: each thread copies a row from the input to a shared memory tile, then -// copies a column from the shared memory tile to the output. +// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads +// transposes one tile: each thread copies a row from the input to a shared +// memory tile, then copies a column from the shared memory tile to the output. // // `tile_size` should usually be same as warp size. // @@ -557,9 +557,10 @@ bool AreShapesForTranspose021(const Shape& a, const Shape& b) { // in any case, the number of blocks we can launch is limited. // // This is the same algorithm in CUDA: -// https://github.com/tensorflow/tensorflow/blob/6172351b81af76d0b819fea6bb478cbd4016d6c2/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L183 +// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189 int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, - const int64 tile_size, llvm::IRBuilder<>* builder) { + const int64 tile_size, const int64 num_rows, + llvm::IRBuilder<>* builder) { // Adds `addend` to the given `dim` of `index`. auto offset_dim = [builder](llvm_ir::IrArray::Index index, llvm::Value* addend, int64 dim) { @@ -590,18 +591,29 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, // let x = threadIdx.x llvm::Value* x = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, tile_size, static_cast<llvm::Instruction*>(x)); + llvm_ir::AddRangeMetadata(0, num_rows * tile_size, + static_cast<llvm::Instruction*>(x)); x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true, "thread.id.x"); + // computing logical thread ids + // logical_x = x % tile_size + auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size)); + + // logical_y = x / tile_size + auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size)); + // `emit_cp` emits equivalent to following pseudocode: // if (tile_size == tile_width && tile_size == tile_height) { - // unroll for (y in 0..tile_size) { - // emit_cp_element(index + {0, y, 0}, y); + // unroll for (i in range(0, tile_size, num_rows)) { + // emit_cp_element(index + {0, i, 0}, y + logical_y); // } // } else if (x < tile_width) { - // for (y in 0..tile_height) { - // emit_cp_element(index + {0, y, 0}, y); + // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows; + // for (i in range(0, tile_height_upperbound, num_rows)) { + // y_loc = i + logical_y; + // if (y_loc < tile_height) + // emit_cp_element(index + {0, i, 0}, y_loc); // } // } // @@ -615,32 +627,50 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, // tile, whether which is row or column is a function of whether we're copying // from input or to output, and `index` is the index into the input or output // array. - auto emit_cp_tile = [builder, tile_size, x, &offset_dim]( - std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)> - emit_cp_element, - llvm::Value* tile_width, llvm::Value* tile_height, - const llvm_ir::IrArray::Index& index, const string& loop_name) { + auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x, + logical_y]( + std::function<void(const llvm_ir::IrArray::Index&, + llvm::Value*)> + emit_cp_element, + llvm::Value* tile_width, llvm::Value* tile_height, + const llvm_ir::IrArray::Index& index, + const string& loop_name) { llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse( builder->CreateAnd( builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width), builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)), "not_last_row", builder); builder->SetInsertPoint(if_not_last_row.true_block->getTerminator()); - for (int64 i = 0; i < tile_size; ++i) { - emit_cp_element(offset_dim(index, builder->getInt64(i), /*dim=*/1), - builder->getInt64(i)); + for (int64 i = 0; i < tile_size; i += num_rows) { + auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1); + auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y); + emit_cp_element(source_idx, y_loc); } builder->SetInsertPoint(if_not_last_row.false_block->getTerminator()); llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(x, tile_width), "in_tile", builder); + builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder); builder->SetInsertPoint(if_in_tile.true_block->getTerminator()); - auto loop = llvm_ir::ForLoop::EmitForLoop(loop_name, builder->getInt64(0), - tile_height, builder->getInt64(1), - builder); + + // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows + auto tile_height_upper_bound = builder->CreateMul( + builder->CreateUDiv( + builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)), + builder->getInt64(num_rows)), + builder->getInt64(num_rows)); + + auto loop = llvm_ir::ForLoop::EmitForLoop( + loop_name, builder->getInt64(0), tile_height_upper_bound, + builder->getInt64(num_rows), builder); llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder); builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator()); + + auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y); + auto if_y_in_tile = llvm_ir::EmitIfThenElse( + builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder); + builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator()); + emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1), - loop->GetIndVarValue()); + y_loc); builder->SetInsertPoint(if_not_last_row.after_block->getTerminator()); }; @@ -673,7 +703,8 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, index; }); const llvm_ir::IrArray::Index input_index = - offset_dim(input_tile_origin, x, /*dim=*/2); + offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y, + /*dim=*/1); std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size()); // Only last row or column may not have full size. for (int i = 1; i < 3; ++i) { @@ -688,11 +719,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, // Load data from input memory to shared memory tile. emit_cp_tile( // tile[y, x] = input_array[index] - [builder, tile, x, &input](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { + [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index, + llvm::Value* y) { builder->CreateStore( input.EmitReadArrayElement(index, builder, "input_element"), - builder->CreateGEP(tile, {builder->getInt64(0), y, x})); + builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x})); }, tile_dims[2], tile_dims[1], input_index, "input"); @@ -706,17 +737,18 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, const llvm_ir::IrArray::Index output_tile_origin( Permute({0, 2, 1}, input_tile_origin.multidim())); const llvm_ir::IrArray::Index output_index = - offset_dim(output_tile_origin, x, /*dim=*/2); + offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2), + logical_y, /*dim=*/1); // Store data from shared memory tile to output memory. emit_cp_tile( // output_array[index] = tile[x, y] - [builder, tile, x, &output](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { + [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index, + llvm::Value* y) { output.EmitWriteArrayElement( index, builder->CreateLoad( - builder->CreateGEP(tile, {builder->getInt64(0), x, y}), + builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}), "output_element"), builder); }, @@ -742,13 +774,14 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { thunk_sequence_->emplace_back(BuildKernelThunk(copy)); VLOG(3) << "Emitting tiled 0-2-1 transposition"; constexpr int64 tile_size = 32; + constexpr int64 num_rows = 8; int64 num_tiles = EmitTranspose021Tiled( GetIrArray(*(copy->operand(0))) .CastToShape(reduced_input_shape, &ir_builder_), GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), - tile_size, &ir_builder_); - UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(), - ir_emitter_context_->llvm_module()); + tile_size, num_rows, &ir_builder_); + UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), + LastThunk(), ir_emitter_context_->llvm_module()); return Status::OK(); } |