Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 1715
1 file changed, 1036 insertions(+), 679 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bdb9e77da4..1f31a7f36b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@@ -58,10 +59,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -70,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
@@ -79,6 +82,7 @@ namespace gpu {
namespace {
+using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::InlinedVector;
@@ -211,7 +215,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
llvm::LLVMContext& context = module->getContext();
llvm::FunctionType* kernel_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(context),
- std::vector<llvm::Type*>(args.size(), ir_builder_.getInt8PtrTy()),
+ std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
/*isVarArg=*/false);
llvm::Function* kernel =
llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
@@ -228,7 +232,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
kernel->addParamAttr(
arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment,
- kCudaMallocAlignBytes));
+ alloc->is_entry_computation_parameter()
+ ? kEntryParameterAlignBytes
+ : kXlaAllocatedBufferAlignBytes));
if (alloc->IsPreallocatedTempBuffer()) {
fn_arg->setName("temp_buf");
@@ -247,7 +253,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
nvvm_annotations_node->addOperand(llvm::MDNode::get(
context, {llvm::ConstantAsMetadata::get(kernel),
llvm::MDString::get(context, "kernel"),
- llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))}));
+ llvm::ConstantAsMetadata::get(b_.getInt32(1))}));
// Update the insert point to the entry basic block.
llvm::BasicBlock* entry_bb =
@@ -255,7 +261,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
// Emit a "return void" at entry_bb's end, and set the insert point before
// that return instruction.
- ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
+ b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
return kernel;
}
@@ -293,7 +299,7 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
// range of i32.
// Otherwise, the return type is i64.
llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
- llvm::IRBuilder<>* ir_builder) {
+ llvm::IRBuilder<>* b) {
// Find the unnested hlo instruction for which the kernel is generated.
const HloInstruction* unnested_hlo = hlo;
const HloComputation* computation = hlo->parent();
@@ -314,7 +320,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
return in_range;
};
- llvm::Type* i64_ty = ir_builder->getInt64Ty();
+ llvm::Type* i64_ty = b->getInt64Ty();
// Check launch dimension
if (!IsInt32(launch_size)) {
return i64_ty;
@@ -343,7 +349,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
}
}
- return ir_builder->getInt32Ty();
+ return b->getInt32Ty();
}
} // namespace
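
Editorial note: the helper above narrows the kernel's index type to i32 whenever it can prove every index fits, since 64-bit address arithmetic costs extra instructions on the GPU. A minimal standalone sketch of the selection rule, with hypothetical names (the real check also walks the operands of a fused instruction):

#include <cstdint>
#include <limits>
#include <vector>

// Returns true when v is usable as a 32-bit index.
static bool FitsInInt32(int64_t v) {
  return v <= std::numeric_limits<int32_t>::max();
}

// Use 32-bit indexing only if the launch bound and every involved element
// count fit in int32; otherwise fall back to 64-bit indexing.
static bool UseInt32Indexing(int64_t launch_size,
                             const std::vector<int64_t>& element_counts) {
  if (!FitsInInt32(launch_size)) return false;
  for (int64_t n : element_counts) {
    if (!FitsInInt32(n)) return false;
  }
  return true;
}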
@@ -355,7 +361,8 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
unroll_factor = ComputeMaxUnrollFactor(hlo);
}
- thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor));
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ hlo, /*implements_whole_instruction=*/true, unroll_factor));
return IrEmitter::DefaultAction(hlo);
}
@@ -369,7 +376,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
return IrEmitter::HandleDot(dot);
}
@@ -379,7 +387,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
- thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
return IrEmitter::HandleConvolution(convolution);
}
@@ -586,16 +595,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
}
CHECK(first_reduce != nullptr);
- thunks.push_back(BuildKernelThunk(fusion));
+ thunks.push_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
- std::vector<llvm_ir::IrArray> parameter_arrays;
+ std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(
- hlo_module_config_, ir_emitter_context_->llvm_module(),
- &ir_builder_, GetNestedComputer());
+ hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
+ GetNestedComputer());
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
@@ -660,21 +670,22 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// touching the un-updated elements.
// Set up kernel thunk and fused ir emitter.
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
- std::vector<llvm_ir::IrArray> operand_arrays;
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/true));
+ std::vector<IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand, *fusion));
}
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
- &ir_builder_, GetNestedComputer());
+ &b_, GetNestedComputer());
// Shape of the dynamic-update-slice's "update" operand.
Shape update_shape = root->operand(1)->shape();
// Array to write into. Because this is an in-place operation, this is the
// same as operand 0's array.
- llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
+ IrArray output_array = GetIrArray(*fusion, *fusion);
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
update_shape, ir_emitter_context_->device_description());
@@ -685,316 +696,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
fusion, operand_arrays, output_array, &elemental_emitter,
- launch_dimensions, &ir_builder_);
+ launch_dimensions, &b_);
}
+
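
Editorial note: the in-place path above writes only the elements covered by the update. A CUDA sketch of the emitted behavior, flattened to one dimension for brevity (the real operation is N-dimensional, and the emitter generates LLVM IR rather than a hand-written kernel; the name below is illustrative):

// Output aliases operand 0, so only the `update` window is touched.
__global__ void DynamicUpdateSliceInPlace(float* buffer, const float* update,
                                          int buffer_size, int update_size,
                                          int start) {
  // XLA clamps the start index so the update window stays in bounds.
  start = min(max(start, 0), buffer_size - update_size);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < update_size) {
    buffer[start + i] = update[i];
  }
}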
if (ImplementedAsGemm(*fusion)) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
- CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
- int unroll_factor = ComputeMaxUnrollFactor(fusion);
-
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
- return IrEmitter::HandleFusion(fusion);
-}
+ CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
-namespace {
-
-// Returns the indices of the first elements of all consecutive subarrays of the
-// given array. For example:
-// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
- std::vector<size_t> is = {0};
- for (size_t i = 1; i < xs.size(); ++i) {
- if (1 != xs[i] - xs[i - 1]) {
- is.push_back(i);
- }
- }
- return is;
-}
-
-// Merges the sequences of dimensions of the given shape which start at the
-// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
- std::vector<int64> dimensions;
- for (size_t i = 1; i <= segs.size(); ++i) {
- dimensions.push_back(std::accumulate(
- shape.dimensions().begin() + segs[i - 1],
- shape.dimensions().begin() +
- (segs.size() == i ? shape.dimensions().size() : segs[i]),
- 1, std::multiplies<int64>()));
- }
- return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
- dimensions);
-}
-
-// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
-// if so, the normalized and rank-reduced shapes. The shapes must have the same
-// dimensions, so this considers layout only.
-//
-// This function recognizes higher-rank transposes which are elementwise
-// equivalent to a 0-2-1 transpose.
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
- CHECK(ShapeUtil::Compatible(a, b));
- std::vector<int64> perm(a.dimensions().size());
- {
- auto layout_a_orig = LayoutUtil::MinorToMajor(a);
- std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
- auto layout_b_orig = LayoutUtil::MinorToMajor(b);
- std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
- for (size_t i = 0; i < perm.size(); ++i) {
- perm[i] = PositionInContainer(layout_b, layout_a[i]);
- }
+ if (CheckAndEmitHloWithTile021(fusion)) {
+ return Status::OK();
}
- auto segs = ConsecutiveSegments(perm);
- Shape norm_a =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
- Shape norm_b =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
- if (3 == segs.size() && 0 == perm[0]) {
- Shape reduced_a = MergeDimensions(segs, norm_a);
- Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
- b.element_type(),
- Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
- return std::make_tuple(true, reduced_a, reduced_b);
- }
- return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
-}
-
-// Returns whether the given shapes are potentially of a 0-2-1 transpose.
-// As 0-2-1 is a self-inverse permutation, which shape is input or output is
-// arbitrary.
-bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
- return 3 == b.dimensions().size() &&
- ShapeUtil::Compatible(
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
- ShapeUtil::PermuteDimensions(
- {0, 2, 1},
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- b)));
-}
-
-// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
-// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
-// transposes one tile: each thread copies a row from the input to a shared
-// memory tile, then copies a column from the shared memory tile to the output.
-//
-// `tile_size` should usually be same as warp size.
-//
-// Returns (number of tiles = number of thread blocks needed).
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
-// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
-int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
- const int64 tile_size, const int64 num_rows,
- llvm::IRBuilder<>* builder) {
- // Adds `addend` to the given `dim` of `index`.
- auto offset_dim = [builder](llvm_ir::IrArray::Index index,
- llvm::Value* addend, int64 dim) {
- index[dim] = builder->CreateAdd(index[dim], addend);
- return index;
- };
- CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
-
- Shape input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input.GetShape());
- Shape output_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- output.GetShape());
- input = input.CastToShape(input_shape, builder);
- output = output.CastToShape(output_shape, builder);
-
- llvm::Type* tile_type = llvm::ArrayType::get(
- llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
- // One extra here to avoid share memory bank conflict
- tile_size + 1);
- auto* tile = new llvm::GlobalVariable(
- *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
- /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
- llvm::UndefValue::get(tile_type), "tile", nullptr,
- llvm::GlobalValue::NotThreadLocal,
- /*AddressSpace=*/3 /* GPU shared memory */);
-
- // let x = threadIdx.x
- llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
- static_cast<llvm::Instruction*>(x));
- x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
- "thread.id.x");
-
- // computing logical thread ids
- // logical_x = x % tile_size
- auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
-
- // logical_y = x / tile_size
- auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
-
- // `emit_cp` emits equivalent to following pseudocode:
- // if (tile_size == tile_width && tile_size == tile_height) {
- // unroll for (i in range(0, tile_size, num_rows)) {
- // emit_cp_element(index + {0, i, 0}, y + logical_y);
- // }
- // } else if (x < tile_width) {
- // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
- // for (i in range(0, tile_height_upperbound, num_rows)) {
- // y_loc = i + logical_y;
- // if (y_loc < tile_height)
- // emit_cp_element(index + {0, i, 0}, y_loc);
- // }
- // }
- //
- // We use this to emit both the copy from input to tile and the copy from tile
- // to output.
- //
- // `index` is the origin of the row or column in the input or output array.
- //
- // `emit_cp_element(index, y)` emits code to copy a single element between the
- // tile and the input or output array, where `y` is the `y`-position in the
- // tile, whether which is row or column is a function of whether we're copying
- // from input or to output, and `index` is the index into the input or output
- // array.
- auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
- logical_y](
- std::function<void(const llvm_ir::IrArray::Index&,
- llvm::Value*)>
- emit_cp_element,
- llvm::Value* tile_width, llvm::Value* tile_height,
- const llvm_ir::IrArray::Index& index,
- const string& loop_name) {
- llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
- builder->CreateAnd(
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
- "not_last_row", builder);
- builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
- for (int64 i = 0; i < tile_size; i += num_rows) {
- auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
- auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
- emit_cp_element(source_idx, y_loc);
- }
- builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
- llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
- builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
-
- // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
- auto tile_height_upper_bound = builder->CreateMul(
- builder->CreateUDiv(
- builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
- builder->getInt64(num_rows)),
- builder->getInt64(num_rows));
-
- auto loop = llvm_ir::ForLoop::EmitForLoop(
- loop_name, builder->getInt64(0), tile_height_upper_bound,
- builder->getInt64(num_rows), builder);
- llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
- builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
-
- auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
- auto if_y_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
- builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
-
- emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
- y_loc);
- builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
- };
-
- auto input_dims_in_tiles = input_shape.dimensions();
- // Unpermuted dimensions are untiled.
- for (int i = 1; i < 3; ++i) {
- input_dims_in_tiles[i] =
- CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
- }
- int64 num_tiles =
- std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
- std::multiplies<int64>());
- const llvm_ir::IrArray::Index input_tile_index(
- /*linear=*/builder->CreateIntCast(
- llvm_ir::AddRangeMetadata(
- 0, num_tiles,
- static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
- builder))),
- builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
- ShapeUtil::MakeShapeWithDescendingLayout(
- PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
- builder);
- const llvm_ir::IrArray::Index input_tile_origin = ({
- llvm_ir::IrArray::Index index = input_tile_index;
- for (int i = 1; i < 3; ++i) {
- index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
- "tile_origin." + std::to_string(i));
- }
- index;
- });
- const llvm_ir::IrArray::Index input_index =
- offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
- /*dim=*/1);
- std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
- // Only last row or column may not have full size.
- for (int i = 1; i < 3; ++i) {
- tile_dims[i] = builder->CreateSelect(
- builder->CreateICmpEQ(input_tile_index[i],
- builder->getInt64(input_dims_in_tiles[i] - 1)),
- builder->getInt64(input_shape.dimensions(i) -
- (input_dims_in_tiles[i] - 1) * tile_size),
- builder->getInt64(tile_size), "tile_size");
- }
-
- // Load data from input memory to shared memory tile.
- emit_cp_tile(
- // tile[y, x] = input_array[index]
- [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- builder->CreateStore(
- input.EmitReadArrayElement(index, builder, "input_element"),
- builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
- },
- tile_dims[2], tile_dims[1], input_index, "input");
+ int unroll_factor = ComputeMaxUnrollFactor(fusion);
- // Wait for all threads to reach this point, lest we copy a value from tile to
- // output before the other thread copies it from input to tile.
- // This is `__syncthreads` in CUDA.
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
-
- const llvm_ir::IrArray::Index output_tile_index(
- Permute({0, 2, 1}, input_tile_index.multidim()));
- const llvm_ir::IrArray::Index output_tile_origin(
- Permute({0, 2, 1}, input_tile_origin.multidim()));
- const llvm_ir::IrArray::Index output_index =
- offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
- logical_y, /*dim=*/1);
-
- // Store data from shared memory tile to output memory.
- emit_cp_tile(
- // output_array[index] = tile[x, y]
- [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- output.EmitWriteArrayElement(
- index,
- builder->CreateLoad(
- builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
- "output_element"),
- builder);
- },
- tile_dims[1], tile_dims[2], output_index, "output");
-
- return num_tiles;
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ fusion, /*implements_whole_instruction=*/true, unroll_factor));
+ return IrEmitter::HandleFusion(fusion);
}
-} // namespace
-
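
Editorial note: the tiled 0-2-1 transpose that this deleted block emitted (now reached through CheckAndEmitHloWithTile021) follows the classic CUDA shared-memory pattern from the conv_ops_gpu_3.cu.cc file linked above. An illustrative hand-written equivalent, not the emitter's output: a 32x32 tile is staged through shared memory by 32x8 threads, and the extra column keeps the strided reads free of bank conflicts.

constexpr int kTileSize = 32;  // tile edge; equal to the warp size
constexpr int kNumRows = 8;    // each thread copies kTileSize / kNumRows elements

// Transposes the minor two dimensions of a [batch, ny, nx] row-major array.
__global__ void Transpose021(const float* in, float* out, int ny, int nx) {
  __shared__ float tile[kTileSize][kTileSize + 1];  // +1 column avoids bank conflicts
  const float* in_b = in + blockIdx.z * ny * nx;
  float* out_b = out + blockIdx.z * ny * nx;

  int x = blockIdx.x * kTileSize + threadIdx.x;
  int y = blockIdx.y * kTileSize + threadIdx.y;
  for (int i = 0; i < kTileSize; i += kNumRows) {
    if (x < nx && y + i < ny) {
      tile[threadIdx.y + i][threadIdx.x] = in_b[(y + i) * nx + x];
    }
  }

  __syncthreads();  // the whole tile must be loaded before anyone reads it back

  x = blockIdx.y * kTileSize + threadIdx.x;  // swapped block coordinates
  y = blockIdx.x * kTileSize + threadIdx.y;
  for (int i = 0; i < kTileSize; i += kNumRows) {
    if (x < ny && y + i < nx) {
      out_b[(y + i) * ny + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}

A launch shaped as grid(ceil(nx/32), ceil(ny/32), batch) with block(32, 8) gives one thread block per tile, matching the "number of tiles = number of thread blocks" contract of the removed EmitTranspose021Tiled.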
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
*copy)) {
@@ -1006,25 +728,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}
- bool is_transpose_021;
- Shape reduced_input_shape, reduced_output_shape;
- std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
- IsTranspose021(copy->operand(0)->shape(), copy->shape());
- if (is_transpose_021 &&
- reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
- reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
- thunk_sequence_->emplace_back(BuildKernelThunk(copy));
- VLOG(3) << "Emitting tiled 0-2-1 transposition";
- constexpr int64 tile_size = 32;
- constexpr int64 num_rows = 8;
- int64 num_tiles = EmitTranspose021Tiled(
- GetIrArray(*copy->operand(0), *copy)
- .CastToShape(reduced_input_shape, &ir_builder_),
- GetIrArray(*copy, *copy)
- .CastToShape(reduced_output_shape, &ir_builder_),
- tile_size, num_rows, &ir_builder_);
- UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
- LastThunk(), ir_emitter_context_->llvm_module());
+ if (CheckAndEmitHloWithTile021(copy)) {
return Status::OK();
}
@@ -1032,7 +736,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
- const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ const HloInstruction* reduce, const IrArray::Index& index,
tensorflow::gtl::ArraySlice<
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
@@ -1040,11 +744,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
const HloInstruction* output = reduce->parent()->FusionInstruction();
llvm::Value* extra_output_address =
GetIrArray(*output, *output, extra_output_gens[i].second)
- .EmitArrayElementAddress(index, &ir_builder_,
+ .EmitArrayElementAddress(index, &b_,
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- ir_builder_.CreateStore(extra_output_ir_value, extra_output_address);
+ b_.CreateStore(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
@@ -1074,12 +778,10 @@ Status IrEmitterUnnested::EmitReductionToScalar(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
- llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -1121,59 +823,57 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // and threads_per_block is a multiple of warpSize.
// reduce_kernel<<<num_blocks, threads_per_block>>>();
//
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ llvm::Value* partial_reduction_result_address =
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ "element_id_in_tile", index_typed_constant(0),
+ index_typed_constant(kTileSize), index_typed_constant(1), &b_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)),
+ &b_);
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)),
- "x_in_bounds", &ir_builder_);
+ b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
+ &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm_ir::IrArray::Index input_index(
- /*linear=*/x, input_shape, &ir_builder_);
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
+ IrArray::Index input_index(
+ /*linear=*/x, input_shape, &b_);
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1184,49 +884,48 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)));
+ llvm::Value* x_end = b_.CreateNSWAdd(
+ index_typed_constant(kTileSize),
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
- llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)),
- ir_builder_.getInt1(all_threads_in_bounds));
+ llvm::Value* tile_in_bounds =
+ b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
- &ir_builder_);
+ llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
// After the if-then-else statement on tile_in_bounds, emit calls to
// shfl_down that accumulate the partial reduction results of all threads
// from the warp.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
- &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_);
int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? ir_builder_.getIntNTy(bit_width)
+ ? b_.getIntNTy(bit_width)
: element_ir_type;
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
- element_ir_type, nullptr, "result_from_other_lane");
+ llvm::Value* result_from_other_lane =
+ b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
+ llvm::Value* partial_reduction_result = b_.CreateLoad(
+ b_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
"partial_reduction_result");
- ir_builder_.CreateStore(
- EmitShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance),
- &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
+ << "Requires block size a multiple of the warp size, otherwise we "
+ "will read undefined elements.";
+ b_.CreateStore(
+ EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ b_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1240,24 +939,23 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit an atomic operation that accumulates the partial reduction result of
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
- llvm::Value* lane_id = ir_builder_.CreateURem(
- x_in_tiles, index_typed_const(kWarpSize), "lane_id");
+ llvm::Value* lane_id =
+ b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
- "lane_id_is_zero", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
- &ir_builder_);
+ b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- /*linear=*/ir_builder_.getInt64(0),
+ IrArray::Index(
+ /*linear=*/b_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
+ &b_),
+ &b_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
}
@@ -1271,7 +969,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
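
Editorial note: the EmitFullWarpShuffleDown loop above is the standard warp-level tree reduction; the new CHECK exists because every lane participates in the shuffle, so a block whose size is not a multiple of the warp size would pull undefined values from inactive lanes. A hedged CUDA equivalent, specialized to a float sum (helper name is illustrative):

__inline__ __device__ float WarpReduceSum(float val) {
  // Halving the stride each step, lane 0 ends up with the sum of all 32
  // lanes after log2(warpSize) shuffles. The full mask assumes every lane
  // is active, hence the block-size requirement.
  for (int offset = warpSize / 2; offset >= 1; offset /= 2) {
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
  return val;
}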
@@ -1284,8 +982,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
tensorflow::gtl::ArraySlice<
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
- // Divide the input matrix into tiles of size Kx1. For example, when the
- // input matrix is 4x4 and K=2, the tiled matrix looks like
+ // Divide the input matrix into tiles of size KxL. For example, when the
+ // input matrix is 4x4, K=2, and L=1, the tiled matrix looks like
//
// 0123
// 0123
@@ -1297,100 +995,131 @@ Status IrEmitterUnnested::EmitColumnReduction(
//
// We choose 128 as the tile size based on empirical evidence. It's big enough
// to reduce the amount of atomic adds in the end, maximizing the memory
- // bandwidth.
- constexpr int64 kTileSize = 128;
+ // bandwidth. A tile width of 2 allows for high memory bandwidth utilization
+ // on 16b input data.
+ constexpr int64 kTileHeight = 128;
+ constexpr int64 kTileWidth = 2;
- // If the height is not a multiple of the tile size, we pad the bottom of the
+ // If the height is not a multiple of kTileHeight, we pad the bottom of the
// input matrix.
- const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+ const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
+ // If width is not a multiple of kTileWidth the rightmost thread will process
+ // fewer input elements.
+ const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
+ Shape tiled_input_shape =
+ ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
+ {height_in_tiles, width_in_tiles}, {1, 0});
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
- llvm::Type* index_ty = ir_builder_.getInt64Ty();
+ llvm::Type* index_ty = b_.getInt64Ty();
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
// for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
- // linear_index < height_in_tiles * width;
+ // linear_index < height_in_tiles * width_in_tiles;
// linear_index += blockDim.x * gridDim.x) {
- // y_in_tiles = linear_index / width;
- // x = linear_index % width;
+ // y_in_tiles = linear_index / width_in_tiles;
+ // x_in_tiles = linear_index % width_in_tiles;
//
- // partial_result = init_value;
- // if (height % kTileSize == 0 ||
- // y_in_tiles * kTileSize + kTileSize <= height) {
- // for (element_id_in_tile : range(kTileSize)) {
- // y = y_in_tiles * kTileSize + element_id_in_tile;
- // partial_result = Reducer(partial_result, input[y][x]);
+ // partial_results[kTileWidth] = init_values;
+ // tile_in_y_bounds = height % kTileHeight == 0 ||
+ // y_in_tiles * kTileHeight + kTileHeight <= height;
+ // tile_in_x_bounds = width % kTileWidth == 0 ||
+ // x_in_tiles * kTileWidth + kTileWidth <= width;
+ // // The implementation handles y and x bound checks separately.
+ // if (tile_in_y_bounds && tile_in_x_bounds) {
+ // for (y_offset : range(kTileHeight)) {
+ // y = y_in_tiles * kTileHeight + y_offset;
+ // for (x_offset : range(kTileWidth)) {
+ // x = x_in_tiles * kTileWidth + x_offset;
+ // partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+ // }
// }
// } else {
- // for (element_id_in_tile : range(kTileSize)) {
- // y = y_in_tiles * kTileSize + element_id_in_tile;
- // if (y < height) {
- // partial_result = Reducer(partial_result, input[y][x]);
+ // for (y_offset : range(kTileHeight)) {
+ // y = y_in_tiles * kTileHeight + y_offset;
+ // for (x_offset : range(kTileWidth)) {
+ // x = x_in_tiles * kTileWidth + x_offset;
+ // if (y < height && x < width) {
+ // partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+ // }
// }
// }
// }
- // AtomicReducer(&output[x], partial_result);
+ // for (x_offset : range(kTileWidth)) {
+ // AtomicReducer(&output[x_in_tiles * kTileWidth + x_offset], partial_results[x_offset]);
+ // }
// }
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
- partial_reduction_result_addresses.push_back(
- partial_reduction_result_address);
+ for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+ llvm::Value* partial_reduction_result_address =
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ partial_reduction_result_addresses.push_back(
+ partial_reduction_result_address);
+ }
}
// Emit an inner for-loop that partially reduces the elements in the given
// tile.
llvm::Value* y_in_tiles = tile_index[0];
- llvm::Value* x = tile_index[1];
+ llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
+ y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+ auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
+ bool tile_in_x_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ "element_id_in_tile", index_typed_constant(0),
+ index_typed_constant(kTileHeight), index_typed_constant(1), &b_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- llvm::Value* y = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)),
+ &b_);
+ llvm::Value* y = b_.CreateNSWAdd(
+ b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
tile_element_loop->GetIndVarValue());
- // Unless we know the tile is entirely in bounds, we have to emit a
- // y-in-bounds check before reading from the input.
- if (!tile_in_bounds) {
+ // Unless we know that y is in bounds, we have to emit a check before
+ // reading from the input.
+ if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(y, index_typed_const(height)),
- "y_in_bounds", &ir_builder_);
+ b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
+ &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
- {
+ for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
+ // Unless we know that x is in bounds, we have to emit a check before
+ // reading from the input.
+ if (!tile_in_x_bounds) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
+ }
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1406,67 +1135,95 @@ Status IrEmitterUnnested::EmitColumnReduction(
const Shape input_matrix_shape =
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
- const llvm_ir::IrArray::Index input_matrix_index(
- {y, x}, input_matrix_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
+ &b_);
+ const IrArray::Index input_index =
input_matrix_index
.SourceIndexOfReshape(input_matrix_shape,
- normalized_input_shape, &ir_builder_)
+ normalized_input_shape, &b_)
.SourceIndexOfTranspose(normalized_input_shape, input_shape,
- transpose_dimension_mapping,
- &ir_builder_);
+ transpose_dimension_mapping, &b_);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
- {partial_reduction_result_addresses[i], input_address},
- partial_reduction_result_addresses[i]));
+ {partial_reduction_result_addresses[i * kTileWidth + x_offset],
+ input_address},
+ partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+ TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
+ extra_output_gens));
}
- return EmitExtraOutputsForReduce(reduce, input_index,
- extra_output_gens);
}
+ return Status::OK();
};
- // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
- // immediately beyond the tile.
- llvm::Value* y_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)));
- llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(y_end, index_typed_const(height)),
- ir_builder_.getInt1(height % kTileSize == 0));
- // The tile is entirely in bound if "height" is a multiple of kTileSize or
+ // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
+ // that's immediately beyond the tile.
+ llvm::Value* y_end = b_.CreateNSWAdd(
+ index_typed_constant(kTileHeight),
+ b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
+ // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
+ // that's immediately beyond the tile.
+ llvm::Value* x_end = b_.CreateNSWAdd(
+ index_typed_constant(kTileWidth),
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* tile_in_y_bounds =
+ b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
+ llvm::Value* tile_in_x_bounds =
+ b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
+ // The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
- llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
-
- // After the if-then-else statement on tile_in_bounds, emit atomic
- // operations to accumulate the partial reduction result to the output
- // element.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
- &ir_builder_);
+ llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_);
+ // The tile is in x bounds if "width" is a multiple of kTileWidth or
+ // x_end <= width.
+ llvm_ir::LlvmIfData if_tile_in_x_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+ /*tile_in_x_bounds=*/true));
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+ /*tile_in_x_bounds=*/false));
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_);
+ if_tile_in_x_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+ /*tile_in_x_bounds=*/true));
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+ /*tile_in_x_bounds=*/false));
+
+ // After the nested if-then-else statement on tile_in_y_bounds and
+ // tile_in_x_bounds, emit atomic operations to accumulate the partial
+ // reduction result to the output element.
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_);
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* output_address =
- GetIrArray(*output, *output, reduce_output_shapes[i])
- .EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- x,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+ llvm::Value* x = b_.CreateNSWAdd(
+ b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
+ llvm::Value* output_address =
+ GetIrArray(*output, *output, reduce_output_shapes[i])
+ .EmitArrayElementAddress(
+ IrArray::Index(
+ x,
+ ShapeUtil::GetSubshape(output->shape(),
+ reduce_output_shapes[i]),
+ &b_),
+ &b_, "output_element_address");
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address,
+ partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+ }
}
return Status::OK();
};
@@ -1478,7 +1235,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
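
Editorial note: a CUDA analogue of the KxL column-reduction tiling sketched in the pseudocode above, specialized to a float sum. Illustrative only: the emitter produces this structure as LLVM IR and calls the user-provided reducer, and the output is pre-filled by an initializer thunk rather than assumed zero.

constexpr int kTileHeight = 128;
constexpr int kTileWidth = 2;

__global__ void ColumnReduceSum(const float* in, float* out,
                                int height, int width) {
  int width_in_tiles = (width + kTileWidth - 1) / kTileWidth;
  int height_in_tiles = (height + kTileHeight - 1) / kTileHeight;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= height_in_tiles * width_in_tiles) return;

  int y0 = (idx / width_in_tiles) * kTileHeight;
  int x0 = (idx % width_in_tiles) * kTileWidth;

  float partial[kTileWidth] = {};  // one partial result per tile column
  for (int i = 0; i < kTileHeight && y0 + i < height; ++i) {
    for (int j = 0; j < kTileWidth; ++j) {
      if (x0 + j < width) {
        partial[j] += in[(y0 + i) * width + (x0 + j)];
      }
    }
  }
  // Tiles stacked on the same columns combine through atomics.
  for (int j = 0; j < kTileWidth; ++j) {
    if (x0 + j < width) {
      atomicAdd(&out[x0 + j], partial[j]);
    }
  }
}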
@@ -1628,28 +1385,25 @@ Status IrEmitterUnnested::EmitRowReduction(
{depth / z_tile_size, height, width_in_tiles}, {2, 1, 0});
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
- llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
- auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
- ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ llvm::Value* partial_reduction_result_address =
+ b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
+ b_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1658,25 +1412,25 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
llvm::Value* warp_id =
- ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id");
+ b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id =
- ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id");
+ b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = ir_builder_.CreateNSWAdd(
- lane_id, ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
- ir_builder_.CreateNSWAdd(
- index_typed_const(x_tile_size - 1),
- ir_builder_.CreateNSWMul(
- warp_id, index_typed_const(x_tile_size)))));
+ llvm::Value* last_x = b_.CreateNSWAdd(
+ lane_id,
+ b_.CreateNSWMul(
+ index_typed_constant(kWarpSize),
+ b_.CreateNSWAdd(
+ index_typed_constant(x_tile_size - 1),
+ b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
- &ir_builder_,
+ &b_,
/*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll,
/*prevent_vectorization=*/false);
@@ -1685,22 +1439,22 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = ir_builder_.CreateNSWAdd(
+ llvm::Value* z = b_.CreateNSWAdd(
z_indvar,
- ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile));
+ b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(x_tile_loop_bound),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(x_tile_loop_bound),
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
+ llvm::Value* x = b_.CreateNSWAdd(
lane_id,
- ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
- ir_builder_.CreateNSWAdd(
- x_indvar, ir_builder_.CreateNSWMul(
+ b_.CreateNSWMul(
+ index_typed_constant(kWarpSize),
+ b_.CreateNSWAdd(
+ x_indvar, b_.CreateNSWMul(
warp_id, llvm::ConstantInt::get(
index_ty, x_tile_size)))));
@@ -1709,17 +1463,16 @@ Status IrEmitterUnnested::EmitRowReduction(
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(width)),
- "x_in_bounds", &ir_builder_);
- // Points ir_builder_ to the then-block.
+ b_.CreateICmpULT(x, index_typed_constant(width)),
+ "x_in_bounds", &b_);
+ // Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
- &ir_builder_);
+ &b_);
}
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address =
- ir_builder_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1737,21 +1490,20 @@ Status IrEmitterUnnested::EmitRowReduction(
const Shape input_3d_tensor_shape =
ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {depth, height, width});
- const llvm_ir::IrArray::Index input_3d_tensor_index(
- {z, y, x}, input_3d_tensor_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_3d_tensor_index(
+ {z, y, x}, input_3d_tensor_shape, &b_);
+ const IrArray::Index input_index =
input_3d_tensor_index
.SourceIndexOfReshape(input_3d_tensor_shape,
- normalized_input_shape,
- &ir_builder_)
+ normalized_input_shape, &b_)
.SourceIndexOfTranspose(
normalized_input_shape, input_shape,
- transpose_dimension_mapping, &ir_builder_);
+ transpose_dimension_mapping, &b_);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ b_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1765,14 +1517,14 @@ Status IrEmitterUnnested::EmitRowReduction(
};
return ksl.For("z_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(z_tile_size),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(z_tile_size),
/*step=*/1, emit_z_tile_element_loop);
};
- llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- ir_builder_.CreateICmpULT(last_x, index_typed_const(width)));
+ llvm::Value* tile_in_bounds =
+ b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ b_.CreateICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1795,23 +1547,25 @@ Status IrEmitterUnnested::EmitRowReduction(
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? ir_builder_.getIntNTy(bit_width)
+ ? b_.getIntNTy(bit_width)
: element_ir_type;
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
- element_ir_type, nullptr, "result_from_other_lane");
+ llvm::Value* result_from_other_lane =
+ b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
+ llvm::Value* partial_reduction_result = b_.CreateLoad(
+ b_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
"partial_reduction_result");
- ir_builder_.CreateStore(
- EmitShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance),
- &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+          CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
+              << "Requires the block size to be a multiple of the warp size, "
+                 "otherwise we will read undefined elements.";
+ b_.CreateStore(
+ EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ b_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
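The shuffle loop above implements a warp-level butterfly reduction. For reference, here is a minimal CPU model of the schedule it emits (shuffle distances 16, 8, 4, 2, 1 across a 32-lane warp, with addition standing in for the nested reducer); the lane array and the printout are illustrative only, not emitter code:

#include <cstdio>

// CPU model of the emitted warp reduction: at each step, lane i combines its
// partial result with the value held by lane i + distance (what shfl.down
// returns). After log2(32) = 5 steps, lane 0 holds the whole warp's result.
constexpr int kWarpSize = 32;

int main() {
  float lane[kWarpSize];
  for (int i = 0; i < kWarpSize; ++i) lane[i] = 1.0f;  // each lane holds 1.0
  for (int distance = kWarpSize / 2; distance >= 1; distance /= 2) {
    float shuffled[kWarpSize];
    for (int i = 0; i < kWarpSize; ++i) {
      // On a real GPU, lanes whose partner falls off the warp read undefined
      // data; we zero them here since only lane 0's result is consumed.
      shuffled[i] = (i + distance < kWarpSize) ? lane[i + distance] : 0.0f;
    }
    for (int i = 0; i < kWarpSize; ++i) lane[i] += shuffled[i];
  }
  std::printf("lane 0 holds %.1f\n", lane[0]);  // prints 32.0
  return 0;
}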
@@ -1826,20 +1580,18 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
- "lane_id_is_zero", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
- &ir_builder_);
+ b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
+ &b_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- y,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
- &ir_builder_, "output_element_address");
+ IrArray::Index(y,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &b_),
+ &b_, "output_element_address");
// We don't need to emit atomic operations if there is only one tile of
// results. 'depth' is the z dimension, 'width' is the x dimension.
if (z_tile_size >= depth && x_tile_size >= width) {
@@ -1863,7 +1615,7 @@ Status IrEmitterUnnested::EmitRowReduction(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(reduce), index_ty);
}
@@ -1982,31 +1734,33 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
BuildInitializerThunk(reduce));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(reduce));
+ thunks.push_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
- reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*input, *reduce)
- .EmitReadArrayElement(index, &ir_builder_);
+ reduce, input->shape(), {[&](const IrArray::Index& index) {
+ return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_);
}},
- {[&](const llvm_ir::IrArray::Index& index) {
+ {[&](const IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
- .EmitReadArrayElement(index, &ir_builder_);
+ .EmitReadArrayElement(index, &b_);
}},
dimensions_to_reduce, {reducer}, {{}}, {});
}
- thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/true));
return IrEmitter::HandleReduce(reduce);
}
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
bool all_tuple_elements_have_buffer =
c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
- return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
- tuple_element);
+ return ir_emitter_context_->buffer_assignment()
+ .GetUniqueTopLevelSlice(tuple_element)
+ .ok();
});
// Tuples (especially tuples that are the final result of a computation) can
// be so huge that if we were to emit a kernel that took each tuple element as
@@ -2027,7 +1781,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTuple(tuple);
}
@@ -2052,7 +1807,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
BuildInitializerThunk(select_and_scatter));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(select_and_scatter));
+ thunks.push_back(BuildKernelThunk(select_and_scatter,
+ /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
@@ -2065,8 +1821,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
source->shape(), ir_emitter_context_->device_description());
llvm::Type* index_type = GetIndexTypeForKernel(
- select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ select_and_scatter, launch_dimensions.launch_bound(), &b_);
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_type, c);
};
@@ -2089,114 +1845,106 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// selected_index = I
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& source_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
// Allocate space to keep the currently selected value, its index, and a
// boolean flag if the value is initialized. The initialized_flag is set
// false.
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type,
ir_emitter_context_->llvm_module()),
- "selected_value_address", &ir_builder_);
+ "selected_value_address", &b_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- index_type, index_typed_const(rank), "selected_index_address",
- &ir_builder_);
+ index_type, index_typed_constant(rank), "selected_index_address",
+ &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
- ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.getInt1(false),
- initialized_flag_address);
+ b_.getInt1Ty(), "initialized_flag_address", &b_);
+ b_.CreateStore(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
- llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"),
- &ir_builder_, index_type);
+ llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
+ index_type);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
CHECK_GT(dim.size(), 0);
}
- const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ const IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
+ &b_);
    // Compute the operand index to visit and evaluate whether the operand index
    // is within the bounds. The unsigned comparison also covers the check that
    // the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(index_type, source_index.size());
- llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ IrArray::Index operand_index(index_type, source_index.size());
+ llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- source_index[i], index_typed_const(window.dimensions(i).stride()));
- operand_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ llvm::Value* strided_index = b_.CreateNSWMul(
+ source_index[i], index_typed_constant(window.dimensions(i).stride()));
+ operand_index[i] = b_.CreateNSWSub(
+ b_.CreateNSWAdd(strided_index, window_index[i]),
+ index_typed_constant(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = b_.CreateICmpULT(
operand_index[i],
- index_typed_const(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
// Only need to do something if the operand index is within the bounds.
// First check if the initialized_flag is set.
llvm_ir::LlvmIfData if_in_bounds =
- llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateLoad(initialized_flag_address), "initialized",
- &ir_builder_);
+ b_.CreateLoad(initialized_flag_address), "initialized", &b_);
    // If the initialized_flag is false, initialize the selected value and index
    // with the operand element currently being visited.
- llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
- const auto save_operand_index = [&](
- const llvm_ir::IrArray::Index& operand_index) {
+ llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
+ const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- ir_builder_.CreateInBoundsGEP(selected_index_address,
- {ir_builder_.getInt32(i)});
- ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ b_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
- llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
+ IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
- operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
- ir_builder_.CreateStore(operand_data, selected_value_address);
+ operand_array.EmitReadArrayElement(operand_index, &b_);
+ b_.CreateStore(operand_data, selected_value_address);
save_operand_index(operand_index);
- ir_builder_.CreateStore(ir_builder_.getInt1(true),
- initialized_flag_address);
+ b_.CreateStore(b_.getInt1(true), initialized_flag_address);
    // If the initialized_flag is true, call the `select` function to
    // potentially update the selected value and index with the operand element
    // currently being visited.
- llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
- operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
+ operand_array.EmitArrayElementAddress(operand_index, &b_);
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(PRED,
ir_emitter_context_->llvm_module()),
- "select_return_buffer", &ir_builder_);
+ "select_return_buffer", &b_);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer);
+ llvm::Value* result = b_.CreateLoad(select_return_buffer);
    // If the 'select' function returns false, update the selected value and the
    // index to the operand element currently being visited.
- llvm::Value* cond = ir_builder_.CreateICmpNE(
+ llvm::Value* cond = b_.CreateICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
- llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
- selected_value_address);
+ llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
+ b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -2204,20 +1952,19 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// location is computed by calling the `scatter` function with the source
// value and the current output value.
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
- llvm_ir::IrArray::Index selected_index(operand_index.GetType());
+ &b_);
+ IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
- selected_index_address, {ir_builder_.getInt32(i)});
- selected_index.push_back(
- ir_builder_.CreateLoad(selected_index_address_slot));
+ llvm::Value* selected_index_address_slot =
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
- .EmitArrayElementAddress(source_index, &ir_builder_);
+ .EmitArrayElementAddress(source_index, &b_);
llvm::Value* output_value_address =
GetIrArray(*select_and_scatter, *select_and_scatter)
- .EmitArrayElementAddress(selected_index, &ir_builder_);
+ .EmitArrayElementAddress(selected_index, &b_);
return EmitAtomicOperationForNestedComputation(
*select_and_scatter->scatter(), output_value_address,
source_value_address);
@@ -2232,7 +1979,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, source->shape(),
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &b_)
.EmitLoop(IrName(select_and_scatter), index_type);
}
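For reference, the loops emitted above compute the usual select-and-scatter semantics. A naive 1-D CPU sketch, assuming `select` picks the maximum and `scatter` adds (both arbitrary stand-ins for the nested computations, as are the window parameters):

#include <cstdio>
#include <vector>

// Naive 1-D select-and-scatter: for each source element, walk its window over
// the operand, select one in-bounds element (here: the max), then scatter
// (here: add) the source value onto the selected output position.
int main() {
  const std::vector<float> operand = {1, 4, 2, 8, 5, 7};
  const std::vector<float> source = {10, 20, 30};  // one value per window
  const int window_size = 2, stride = 2, padding_low = 0;

  std::vector<float> output(operand.size(), 0.0f);  // init_value = 0
  for (size_t s = 0; s < source.size(); ++s) {
    bool initialized = false;
    size_t selected_index = 0;
    for (int w = 0; w < window_size; ++w) {
      const long i = static_cast<long>(s) * stride + w - padding_low;
      if (i < 0 || i >= static_cast<long>(operand.size())) continue;
      if (!initialized || operand[i] > operand[selected_index]) {
        selected_index = static_cast<size_t>(i);
        initialized = true;
      }
    }
    if (initialized) output[selected_index] += source[s];  // scatter = add
  }
  for (float v : output) std::printf("%.0f ", v);  // prints: 0 10 0 20 0 30
  std::printf("\n");
  return 0;
}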
@@ -2260,15 +2007,92 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
}
Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
- thunk_sequence_->push_back(BuildKernelThunk(random));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(random, /*implements_whole_instruction=*/true));
return IrEmitter::HandleRng(random);
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
- thunk_sequence_->push_back(BuildKernelThunk(select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleSelect(select);
}
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ if (values != nullptr) {
+ // TODO(b/26783907): Also sort the values by their corresponding key.
+ return Unimplemented("Key/Value Sort is not implemented on GPU");
+ }
+
+ // First copy the operand to the output, so that we can sort in-place.
+  // TODO(b/26783907): Share the output buffer with the operand when possible.
+ if (sort->operand(0)->IsConstant()) {
+ thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+ /*source_address=*/sort->operand(0)->literal().untyped_data(),
+ /*destination_buffer=*/GetAllocationSlice(*sort),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+ } else {
+ thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+ /*destination_buffer=*/GetAllocationSlice(*sort),
+ /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+ }
+
+ int64 dimension_to_sort = sort->dimensions(0);
+ int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort);
+ int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
+ auto index_type = b_.getInt64Ty();
+
+ // Naive C++ code for the outer loops:
+ //
+ // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+ // ++stage) {
+ // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+ // SortInPlace(first_xor_mask);
+ // for (int64 mask = stage - 1; mask >= 0; --mask) {
+ // int64 later_xor_mask = 1LL << mask;
+ // SortInPlace(later_xor_mask);
+ // }
+ // }
+ //
+ // This follows the algorithm described on Wikipedia:
+ // https://en.wikipedia.org/wiki/Bitonic_sorter
+
+ for (int64 stage = 0; stage < num_stages; ++stage) {
+ for (int64 mask = stage; mask >= 0; --mask) {
+ thunks.push_back(
+ BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ sort->shape(), ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
+ ir_emitter_context_->llvm_module());
+
+ llvm::Value* xor_mask;
+ if (mask == stage) {
+ xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1);
+ } else {
+ xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask);
+ }
+
+ TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
+ dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask,
+ &b_, &launch_dimensions));
+ }
+ }
+
+ thunk_sequence_->emplace_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), sort));
+ return Status::OK();
+}
+
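To see why the stage/mask schedule in the comment sorts, here is a CPU sketch of the same network; `SortInPlace` models what each `llvm_ir::EmitSortInPlace` kernel launch does (compare element i with element i ^ xor_mask, keep the smaller value at the lower index, and skip out-of-range partners, which behave like +infinity):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// One launch's worth of work for a given xor_mask.
void SortInPlace(std::vector<int>* data, int64_t xor_mask) {
  const int64_t n = static_cast<int64_t>(data->size());
  for (int64_t i = 0; i < n; ++i) {
    const int64_t partner = i ^ xor_mask;
    if (partner > i && partner < n && (*data)[i] > (*data)[partner]) {
      std::swap((*data)[i], (*data)[partner]);
    }
  }
}

int main() {
  std::vector<int> data = {5, 3, 9, 1, 7, 2, 8};  // bound need not be 2^k
  int64_t num_stages = 0;  // Log2Ceiling(data.size())
  while ((int64_t{1} << num_stages) < static_cast<int64_t>(data.size())) {
    ++num_stages;
  }
  for (int64_t stage = 0; stage < num_stages; ++stage) {
    SortInPlace(&data, (int64_t{1} << (stage + 1)) - 1);  // first_xor_mask
    for (int64_t mask = stage - 1; mask >= 0; --mask) {
      SortInPlace(&data, int64_t{1} << mask);  // later_xor_mask
    }
  }
  for (int v : data) std::printf("%d ", v);  // prints: 1 2 3 5 7 8 9
  std::printf("\n");
  return 0;
}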
+Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
+ thunk_sequence_->push_back(
+ BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
+ return IrEmitter::HandleTupleSelect(tuple_select);
+}
+
Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
if (hlo_module_config_.replica_count() != 1) {
// TODO(b/33011107): Support nontrivial cross replica sum on GPU.
@@ -2304,12 +2128,12 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
- /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), crs));
+ GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
MakeUnique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
@@ -2324,6 +2148,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
return Status::OK();
}
+Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
+ thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed));
+ return Status::OK();
+}
+
// Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO).
//
@@ -2443,7 +2272,8 @@ GetHloBufferSlices(const HloInstruction* hlo,
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst, int unroll_factor) {
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2508,18 +2338,16 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
<< " is found in slice " << slice.ToString() << " at GTE index "
<< gte_index.ToString();
- llvm::Value* loc =
- ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {ir_builder_.getInt64(slice.offset())});
+ llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
// If gte_index is nonempty, we have to dereference `loc` to get to the
// value we're ultimately interested in.
llvm::Type* int8_double_pointer =
- llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0);
+ llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = ir_builder_.CreateBitCast(loc, int8_double_pointer);
- loc = ir_builder_.CreateLoad(
- ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)}));
+ loc = b_.CreateBitCast(loc, int8_double_pointer);
+ loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
@@ -2531,11 +2359,12 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
} else {
bindings_.SetTempBufferBase(
- llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy()));
+ llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst, unroll_factor);
+ implements_whole_instruction ? inst : nullptr,
+ unroll_factor);
}
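The gte_index walk above resolves a slice inside nested tuple buffers by repeated pointer loads. A plain C++ model of that address computation (the helper name and the toy tuple in main are hypothetical):

#include <cstdint>
#include <cstdio>
#include <vector>

// Start at the allocation base plus the slice offset; each gte_index step
// reinterprets the current location as an array of pointers and loads the
// idx-th entry, mirroring the CreateBitCast/CreateLoad sequence above.
char* ResolveSlice(char* allocation_base, int64_t slice_offset,
                   const std::vector<int64_t>& gte_index) {
  char* loc = allocation_base + slice_offset;
  for (int64_t idx : gte_index) {
    loc = reinterpret_cast<char**>(loc)[idx];
  }
  return loc;
}

int main() {
  // A fake two-element "tuple" buffer holding pointers to two payloads.
  char payload0[4] = {'a'};
  char payload1[4] = {'b'};
  char* tuple[2] = {payload0, payload1};
  char* resolved = ResolveSlice(reinterpret_cast<char*>(tuple), /*offset=*/0,
                                /*gte_index=*/{1});
  std::printf("%c\n", *resolved);  // prints: b
  return 0;
}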
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2569,7 +2398,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
ShapeTree<BufferAllocation::Slice> slices(inst->shape());
slices.ForEachMutableElement(
- [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
@@ -2577,6 +2406,23 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
return MakeUnique<InfeedThunk>(slices, inst);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
+ const HloInstruction* inst) {
+ CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
+
+ ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
+ slices.ForEachMutableElement(
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ auto status_or_slice =
+ ir_emitter_context_->buffer_assignment().GetUniqueSlice(
+ inst->operand(0), index);
+ if (status_or_slice.ok()) {
+ *slice = status_or_slice.ConsumeValueOrDie();
+ }
+ });
+ return MakeUnique<OutfeedThunk>(std::move(slices), inst);
+}
+
namespace {
double GetScalarConstantAsDouble(const Literal& literal) {
switch (literal.shape().element_type()) {
@@ -2692,6 +2538,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
init_value = hlo->operand(init_value->parameter_number());
}
+ // Initializer thunks don't implement a whole instruction, and we want to
+ // profile the whole instruction instead of the individual thunks it consists
+ // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
+ // generate below.
+ //
// In the common case, the initializer is a constant. In this case, emit a
// device-memset call if we can. Currently StreamExecutor only supports
// zeroing and 32-bit memsets.
@@ -2705,7 +2556,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
+ return {
+ MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2723,7 +2575,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {MakeUnique<Memset32BitValueThunk>(
- pattern32, GetAllocationSlice(*hlo, index), hlo)};
+ pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
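The pattern computation above widens small literals into a 32-bit memset word. A CPU sketch of the three cases (all-zero bytes, 8-bit splat, 16-bit splat); the helper name and return convention are illustrative, not the emitter's API:

#include <cstdint>
#include <cstring>
#include <vector>

// Returns true and sets *pattern32 if `bytes` can be written by a 32-bit
// memset: every byte zero, or an 8/16-bit element splatted up to 32 bits.
bool GetMemset32Pattern(const std::vector<uint8_t>& bytes,
                        int bits_per_element, uint32_t* pattern32) {
  bool all_zero = true;
  for (uint8_t b : bytes) all_zero &= (b == 0);
  if (all_zero) {
    *pattern32 = 0;
    return true;
  }
  uint16_t pattern16 = 0;
  if (bits_per_element == 8) {
    pattern16 = uint16_t{bytes[0]} | (uint16_t{bytes[0]} << 8);
  } else if (bits_per_element == 16) {
    std::memcpy(&pattern16, bytes.data(), sizeof(pattern16));
  } else {
    return false;  // fall back to the slow initializer kernel
  }
  *pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16);
  return true;
}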
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2734,12 +2586,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {MakeUnique<Memset32BitValueThunk>(
- word, GetAllocationSlice(*hlo, index), hlo)};
+ word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
// Otherwise fall back to our slow initializer code.
- std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo);
+ std::unique_ptr<KernelThunk> kernel_thunk =
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
LaunchDimensions launch_dimensions =
CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
ir_emitter_context_->device_description());
@@ -2751,12 +2604,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const llvm_ir::IrArray::Index& index) {
+ [=](const IrArray::Index& index) {
return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &ir_builder_);
+ .EmitReadArrayElement(index, &b_);
},
- GetIrArray(*hlo, *hlo, index), launch_dimensions,
- &ir_builder_)
+ GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
.EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
@@ -2940,41 +2792,546 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
- launch_dimensions, &ir_builder_, unroll_factor)
- .EmitLoop(IrName(&hlo),
- GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(),
- &ir_builder_));
+ launch_dimensions, &b_, unroll_factor)
+ .EmitLoop(
+ IrName(&hlo),
+ GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
}
- // For multiple outputs fusion, we need to emit each operand and the root.
- std::vector<llvm_ir::IrArray> output_arrays;
+ // For multioutput fusion, we need to emit each operand and the root.
+ std::vector<IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
- &ir_builder_, unroll_factor)
+ &b_, unroll_factor)
.EmitLoop(IrName(&hlo),
GetIndexTypeForKernel(
- &hlo, launch_dimensions.launch_bound(), &ir_builder_)));
+ &hlo, launch_dimensions.launch_bound(), &b_)));
std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
- ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_,
- module_);
+ b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
return Status::OK();
}
Status IrEmitterUnnested::EmitTargetElementLoop(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator) {
- CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
+ CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind());
return EmitTargetElementLoopInThunk(hlo, element_generator,
static_cast<KernelThunk*>(LastThunk()));
}
+int IrEmitterUnnested::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays->push_back(GetIrArray(hlo, hlo));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_arrays->reserve(num_params);
+ for (const HloInstruction* param : hlo.operands()) {
+ param_arrays->push_back(GetIrArray(*param, hlo));
+ }
+ return num_params;
+}
+
+int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<IrArray>* output_in_reduced_shape_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_in_reduced_shape_arrays->reserve(num_outputs);
+ output_reduced_shapes->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
+ reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(
+ output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_));
+ }
+ } else {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ hlo.shape().element_type(), reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(
+ output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<IrArray>* param_in_reduced_shape_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_in_reduced_shape_arrays->reserve(num_params);
+ param_reduced_shapes->reserve(num_params);
+ for (int64 id = 0; id < num_params; ++id) {
+ if (param_buffers[id] == nullptr) {
+ param_reduced_shapes->push_back(Shape());
+ param_in_reduced_shape_arrays->push_back(IrArray());
+ continue;
+ }
+ const HloInstruction* param = hlo.operand(id);
+ param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ param->shape().element_type(),
+ Permute({0, 2, 1}, reduced_output_dims)));
+ param_in_reduced_shape_arrays->push_back(
+ param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_));
+ }
+ return num_params;
+}
+
+namespace {
+
+// Reads thread_idx.x and converts it to a (y, x) coordinate within a tile of
+// width tile_size, assuming the thread block contains threads_per_tile
+// threads, so that y ranges over [0, threads_per_tile / tile_size).
+std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
+ llvm::IRBuilder<>* builder, llvm::Value* tile_size,
+ int64 threads_per_tile) {
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, threads_per_tile,
+ llvm::cast<llvm::Instruction>(thread_id));
+ thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
+ /*isSigned=*/true, "thread.id.x");
+ auto x = builder->CreateURem(thread_id, tile_size);
+ auto y = builder->CreateUDiv(thread_id, tile_size);
+ return std::make_tuple(y, x);
+}
+
+// Reads block_idx.x, casts it to type index_ty, and adds the assumption that
+// it's in the range [0, num_blocks).
+llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+ int64 num_blocks) {
+ llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, num_blocks,
+ llvm::cast<llvm::Instruction>(block_id));
+ return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
+ "block.id.x");
+}
+
+// Emits code to process up to (tile_size/num_rows) elements in a tile, where
+// `emit_elem_function` emits the code to process one element, `y` and `x` are
+// the coordinates of the first element to process, and `index` is the index of
+// the tile origin. Emits bounds checks to ensure that each processed element
+// lies within the boundary defined by `tile_width` and `tile_height`.
+void EmitTiledElementalCodeWithBoundsCheck(
+ int64 tile_size, int64 num_rows, const IrArray::Index& index,
+ const string& loop_name, KernelSupportLibrary* ksl,
+ llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ llvm::Type* index_ty = tile_width->getType();
+ // Emits a constant value with index type.
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = builder->CreateAdd(index[dim], addend);
+ return index;
+ };
+
+ auto emit_full_tile = [&] {
+ for (int64 i = 0; i < tile_size; i += num_rows) {
+ auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(index_typed_constant(i), y);
+ emit_elem_function(source_idx, y_loc);
+ }
+ };
+
+ auto emit_last_row = [&] {
+ ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] {
+ // tile_height_upper_bound =
+ // ceil(tile_height / num_rows) * num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height,
+ index_typed_constant(num_rows - 1)),
+ index_typed_constant(num_rows)),
+ index_typed_constant(num_rows));
+ ksl->ForReturnVoid(
+ loop_name, /*start=*/index_typed_constant(0),
+ /*end=*/tile_height_upper_bound,
+ /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) {
+ auto y_loc = builder->CreateAdd(y_indvar, y);
+ ksl->IfReturnVoid(
+ "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] {
+ emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1),
+ y_loc);
+ });
+ });
+ });
+ };
+ ksl->IfReturnVoid(
+ "full_tile",
+ builder->CreateAnd(
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width),
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)),
+ emit_full_tile, emit_last_row);
+}
+} // namespace
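The `tile_height_upper_bound` computation above rounds the y trip count up to a whole number of `num_rows` strides so that every thread executes the same loop structure; a scalar sketch of the control flow (names mirror the emitter's, the function pointer is illustrative):

#include <cstdint>

// Scalar model of the bounds-checked tile walk for one (y, x) thread: guard
// the column once, round tile_height up to a multiple of num_rows, and guard
// each visited row individually.
void VisitTileElements(int64_t num_rows, int64_t y, int64_t x,
                       int64_t tile_width, int64_t tile_height,
                       void (*emit_elem)(int64_t y_loc, int64_t x)) {
  if (x >= tile_width) return;  // the "x_in_tile" check
  const int64_t tile_height_upper_bound =
      (tile_height + num_rows - 1) / num_rows * num_rows;
  for (int64_t y_indvar = 0; y_indvar < tile_height_upper_bound;
       y_indvar += num_rows) {
    const int64_t y_loc = y_indvar + y;
    if (y_loc < tile_height) emit_elem(y_loc, x);  // the "y_in_tile" check
  }
}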
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// which have a shape that is a 0-2-1 transpose of the output tensors.
+//
+// For the purpose of tiling, the output tensors have a logical three-component
+// shape in 0-2-1 order while the relevant input parameters have a logical
+// three-component shape in 0-1-2 order, major to minor. The x- and y-dimensions
+// of the tensors are tiled in square tiles of edge length `kTileSize`. Each
+// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each
+// thread copies kTileSize/kNumRows elements from the input to a shared memory
+// tile, then the otherwise "regular hlo kernel" reads from the shared memory
+// instead of the original input.
+//
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
+//
+// `kTileSize` should usually be the same as the warp size. We choose 32 for
+// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
+// to launch fewer blocks so each transposes many tiles.
+LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
+ HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ // Parameters for the tiling algorithm.
+ constexpr int64 kTileSize = 32;
+ constexpr int64 kNumRows = 4;
+ constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
+
+ // Construct IrArrays for the inputs and outputs.
+ std::vector<IrArray> output_arrays;
+ int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
+ std::vector<IrArray> param_arrays;
+ int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+
+ // Allocate shared memory buffers to store the tiled inputs.
+ std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
+ for (int64 id : tiled_param_ids) {
+ const HloInstruction* param = hlo->operand(id);
+ // Add 1 to the minor dimension to reduce shared memory bank conflicts.
+ llvm::Type* tile_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
+ param->shape().element_type(), module_),
+ kTileSize + 1),
+ kTileSize);
+ const int kNVPTXSharedMemoryAddrSpace = 3;
+ auto* tile_base_ptr = new llvm::GlobalVariable(
+ *b_.GetInsertBlock()->getParent()->getParent(), tile_type,
+ /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+ llvm::UndefValue::get(tile_type),
+ llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr,
+ llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
+ param_shmem_buffers[id] = tile_base_ptr;
+ VLOG(3) << "Added shmem buffer for parameter " << id << ": "
+ << llvm_ir::DumpToString(*tile_base_ptr);
+ }
+
+  // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
+  // for the purpose of tiling. Calculate the logical output dimensions measured
+  // in tiles from the reduced output dimensions.
+ std::vector<int64> output_dims_in_tiles = std::vector<int64>(
+ reduced_output_dims.begin(), reduced_output_dims.end());
+ CHECK_EQ(output_dims_in_tiles.size(), 3);
+ for (int i = 1; i < 3; ++i) {
+ output_dims_in_tiles[i] =
+ CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
+ }
+ const int64 num_tiles =
+ c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
+
+ llvm::Type* index_ty =
+ GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ // Cast each output IrArray to its corresponding reduced shape and keep the
+ // reduced shape live during IR emission.
+ std::vector<IrArray> output_in_reduced_shape_arrays;
+ std::vector<Shape> output_reduced_shapes;
+ CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
+ &output_in_reduced_shape_arrays),
+ num_outputs);
+
+ // For each tiled parameter, cast its input IrArray to the corresponding
+ // reduced shape and keep the reduced shape live during IR emission.
+ std::vector<IrArray> param_in_reduced_shape_arrays;
+ std::vector<Shape> param_reduced_shapes;
+ CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ *hlo, param_arrays, param_shmem_buffers, reduced_output_dims,
+ &param_reduced_shapes, &param_in_reduced_shape_arrays),
+ num_params);
+
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* x;
+ llvm::Value* y;
+ std::tie(y, x) = CalculateYXCoordinateWithinTile(
+ &b_, index_typed_constant(kTileSize), kThreadsPerTile);
+
+ // Calculate the index for the current output tile from block_id.
+ const IrArray::Index output_tile_index(
+ GetBlockIdx(&b_, index_ty, num_tiles),
+ ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
+ output_dims_in_tiles),
+ &b_);
+
+ // Output tile origin is the index for the first element of the current output
+ // tile.
+ const IrArray::Index output_tile_origin = [&] {
+ IrArray::Index index = output_tile_index;
+ for (int i = 1; i < 3; ++i) {
+ index[i] =
+ b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
+ }
+ return index;
+ }();
+
+ // Calculate the input tile origin from the output tile origin.
+ const IrArray::Index input_tile_origin(
+ Permute({0, 2, 1}, output_tile_origin.multidim()));
+
+ // Calculate the current output tile bounds in each of the logical dimensions.
+ std::vector<llvm::Value*> output_tile_bounds(3);
+ for (int i = 1; i < 3; ++i) {
+    // Only the last row or column may not have full size.
+ output_tile_bounds[i] = b_.CreateSelect(
+ b_.CreateICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
+ }
+
+ KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
+
+ // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
+ auto emit_tiled_elemental_code_with_bounds_check =
+ [&](const IrArray::Index& index, const string& loop_name,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ EmitTiledElementalCodeWithBoundsCheck(
+ kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width,
+ tile_height, emit_elem_function);
+ };
+
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = b_.CreateAdd(index[dim], addend);
+ return index;
+ };
+ const IrArray::Index input_index =
+ offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+ // Copy input parameter values to shared memory buffers:
+ // tile[y, x] = input[index]
+ emit_tiled_elemental_code_with_bounds_check(
+ input_index, "input", output_tile_bounds[1], output_tile_bounds[2],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id];
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ // TODO(jlebar): Add AA metadata to this store. Tile buffers are
+              // global variables, so LLVM can't infer much about them.
+ b_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &b_,
+ "input_element"),
+ b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
+ }
+ });
+
+  // Wait for all threads to reach this point, lest we copy a value from tile to
+  // output before another thread has copied it from input to tile.
+ // This is `__syncthreads` in CUDA.
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
+
+ llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
+
+ const IrArray::Index output_index =
+ offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+  // Write to output[index] by emitting code as normal, except that values for
+ // the tiled parameters are read from the shmem buffers.
+ if (hlo->opcode() == HloOpcode::kCopy) {
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ // TODO(jlebar): Add AA metadata to this load.
+ llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
+ b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
+ "output_element");
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, load_from_shmem_buffer, &b_);
+ });
+ } else {
+ CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
+ GetNestedComputer());
+ FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+ tiled_param_info.set_y(y_loc);
+ fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+ TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+ IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+ index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+ &b_);
+ const llvm_ir::ElementGenerator& output_generator =
+ fused_emitter.GetRootGenerator();
+ llvm::Value* output_value =
+ output_generator(untiled_index).ValueOrDie();
+ if (hlo->IsMultiOutputFusion()) {
+ CHECK(output_value->getType()->isStructTy());
+ CHECK_EQ(output_value->getType()->getStructNumElements(),
+ output_in_reduced_shape_arrays.size());
+ for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+ output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+ index, b_.CreateExtractValue(output_value, i), &b_);
+ }
+ } else {
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, output_value, &b_);
+ }
+ });
+ }
+
+ // For multioutput fusion, emit a tuple with all the individual outputs.
+ if (hlo->IsMultiOutputFusion()) {
+ std::vector<llvm::Value*> tuple_operand_ptrs;
+ for (int64 i = 0; i < output_arrays.size(); ++i) {
+ tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
+ }
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
+ module_);
+ }
+
+ return launch_dimensions;
+}
+
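As a CPU illustration of the scheme EmitHlo021Tile implements, here is what one block's tile movement looks like, with a plain array standing in for shared memory and a single tile at the origin for brevity (the +1 column models the bank-conflict padding above; all names are illustrative):

#include <cstdio>

constexpr int kTileSize = 32;
constexpr int kNumRows = 4;

// Stage a tile of the row-major input (height rows by width columns) into a
// padded buffer, then write it back transposed. The strided i-loop mirrors
// each thread covering kTileSize/kNumRows rows in a kTileSize x kNumRows
// thread block.
void TransposeTile(const float* in, float* out, int width, int height) {
  static float tile[kTileSize][kTileSize + 1];  // +1 column: bank-conflict pad
  for (int y = 0; y < kNumRows; ++y)
    for (int x = 0; x < kTileSize; ++x)
      for (int i = 0; i < kTileSize; i += kNumRows)
        if (y + i < height && x < width)
          tile[y + i][x] = in[(y + i) * width + x];
  // A real kernel needs __syncthreads() here before any lane reads the tile.
  for (int y = 0; y < kNumRows; ++y)
    for (int x = 0; x < kTileSize; ++x)
      for (int i = 0; i < kTileSize; i += kNumRows)
        if (y + i < width && x < height)
          out[(y + i) * height + x] = tile[x][y + i];
}

int main() {
  float in[2 * 3] = {1, 2, 3, 4, 5, 6};  // 2 rows x 3 cols
  float out[3 * 2] = {};
  TransposeTile(in, out, /*width=*/3, /*height=*/2);
  for (float v : out) std::printf("%.0f ", v);  // prints: 1 4 2 5 3 6
  std::printf("\n");
  return 0;
}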
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+ HloOpcode opcode = hlo->opcode();
+ CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+ CHECK(opcode != HloOpcode::kFusion ||
+ hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
+ << "Only loop fusions are supported.";
+
+ const Shape& output_shape = hlo->IsMultiOutputFusion()
+ ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+ : hlo->shape();
+
+  // If the output_shape reduces to a 0-2-1 shape, find all parameters of the
+  // hlo whose shapes reduce to the corresponding 0-1-2 shape.
+ std::vector<int64> params_012;
+ optional<std::vector<int64>> reduced_dims_021;
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ HloInstruction* operand = hlo->mutable_operand(operand_idx);
+ auto find_transpose_result =
+ llvm_ir::FindTranspose021(operand->shape(), output_shape);
+ if (!find_transpose_result.has_value()) {
+ continue;
+ }
+ const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+ if (!reduced_dims_021.has_value()) {
+ reduced_dims_021 = curr_reduced_dims_021;
+ }
+ if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ // There is more than one possible transpose. Instead of picking one
+ // transpose, we simply give up here.
+ return false;
+ }
+ params_012.push_back(operand_idx);
+ }
+
+ if (!reduced_dims_021.has_value()) {
+ return false;
+ }
+
+ if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
+ (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
+ return false;
+ }
+
+ // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
+ // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb
+ // shared memory per SM. (This is increased to 96kb in Volta, but we don't
+ // use this, in part because it eats into our L1 cache space.)
+ //
+ // For correctness we need to ensure that we don't make more than 48kb worth
+ // of shmem tiles per block. And for performance, we'd probably like to use
+ // significantly less, so that we can fit more than one block at a time on a
+ // gpu core.
+ //
+  // We say without benchmarks that we want at least 3 blocks/core, which caps
+  // each block at 16kb of shmem, i.e. about 3 shmem tiles if the elements are
+  // 32 bits wide. We choose which params get the shmem transpose treatment
+  // arbitrarily; it's not clear if there's a Right Choice.
+  //
+  // This is only sound if tiled transposes are the only place where we use
+  // shared memory in fusions. If in the future other fusible ops use shared
+  // memory, we'll have to adjust this heuristic.
+ constexpr int kMinBlocksPerCore = 3;
+ constexpr int64 kShmemPerCore = 48 * 1024;
+ int64 shmem_used = 0;
+ for (int64 i = 0; i < params_012.size(); ++i) {
+ const HloInstruction* operand = hlo->operand(params_012[i]);
+ shmem_used +=
+ 32 * 33 *
+ ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
+
+ if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
+ // Erase this element and everything after it from params_012.
+ params_012.resize(i);
+ break;
+ }
+ }
+
+  VLOG(3) << "EmitHlo021Tile: emitting 0-2-1 tiling for " << hlo->ToString();
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
+ const LaunchDimensions launch_dimensions =
+ EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
+ UpdateLaunchDimensions(launch_dimensions, LastThunk(),
+ ir_emitter_context_->llvm_module());
+
+ return true;
+}
+
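The shmem accounting in CheckAndEmitHloWithTile021 can be read as a simple budget loop. A sketch (the function name and element-size input are illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// Admit tiled params until three blocks' worth of 32x33 tiles would no longer
// fit in the 48kb of shared memory available per core; the param that breaks
// the budget (and everything after it) is dropped, as in the resize() above.
std::vector<size_t> LimitTiledParams(const std::vector<int64_t>& elem_sizes) {
  constexpr int kMinBlocksPerCore = 3;
  constexpr int64_t kShmemPerCore = 48 * 1024;
  int64_t shmem_used = 0;
  std::vector<size_t> admitted;
  for (size_t i = 0; i < elem_sizes.size(); ++i) {
    shmem_used += 32 * 33 * elem_sizes[i];
    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) break;
    admitted.push_back(i);
  }
  return admitted;
}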
} // namespace gpu
} // namespace xla