Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc  305
1 file changed, 131 insertions, 174 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index d5e07c3afb..f95541cba4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -57,12 +57,12 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
IrEmitterContext* ir_emitter_context, bool is_nested)
: ir_emitter_context_(ir_emitter_context),
module_(ir_emitter_context->llvm_module()),
- ir_builder_(module_->getContext()),
+ b_(module_->getContext()),
bindings_(ir_emitter_context->hlo_module(),
- &ir_emitter_context->buffer_assignment(), &ir_builder_, module_,
+ &ir_emitter_context->buffer_assignment(), &b_, module_,
is_nested),
hlo_module_config_(hlo_module_config) {
- ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
+ b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config.debug_options()
.xla_enable_fast_math()));
}
@@ -71,12 +71,11 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*operand, *hlo)
- .EmitReadArrayElement(index, &ir_builder_);
+ return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
};
}
return EmitTargetElementLoop(
- *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
+ *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
GetNestedComputer())
.MakeElementGenerator(hlo, operand_to_generator));
}
@@ -119,15 +118,10 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
get_tuple_element->shape(), get_tuple_element->tuple_index(),
// TODO(b/26344050): tighten the alignment here
// based on the real element type.
- /*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_));
+ /*alignment=*/1, GetBasePointer(*operand), &b_, module_));
return Status::OK();
}
-Status IrEmitter::HandleSort(HloInstruction*) {
- // TODO(b/26783907): Implement sort on GPU.
- return Unimplemented("sort");
-}
-
Status IrEmitter::HandleSend(HloInstruction*) {
return Unimplemented("Send is not implemented on GPU");
}
@@ -149,8 +143,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* operand : tuple->operands()) {
base_ptrs.push_back(GetBasePointer(*operand));
}
- llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_,
- module_);
+ llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_);
return Status::OK();
}
@@ -171,7 +164,7 @@ Status IrEmitter::EmitCallToNestedComputation(
std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
arguments.push_back(output);
arguments.push_back(bindings_.GetTempBufferBase());
- ir_builder_.CreateCall(emitted_function, arguments);
+ b_.CreateCall(emitted_function, arguments);
return Status::OK();
}
@@ -193,21 +186,20 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
computation.root_instruction()->shape().element_type();
bool is_atomic_integral = element_type == S32 || element_type == U32 ||
element_type == S64 || element_type == U64;
- llvm::Value* source = ir_builder_.CreateLoad(source_address, "source");
+ llvm::Value* source = b_.CreateLoad(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
if (element_type == F32) {
// F32 + F32
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
{output_address, source},
- {output_address->getType()}, &ir_builder_);
+ {output_address->getType()}, &b_);
return true;
}
if (is_atomic_integral) {
// integral + integral
- ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address,
- source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
}
@@ -218,8 +210,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
- ir_builder_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ b_.CreateAtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -228,8 +220,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
- ir_builder_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ b_.CreateAtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
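
For reference, the cases matched above correspond one-to-one to CUDA's hardware atomics. A device-code sketch (illustrative only; the emitter produces the LLVM atomics named in the comments, not CUDA source, and DirectAtomicsSketch is a hypothetical name):

__device__ void DirectAtomicsSketch(float* f32, int* s32, unsigned int* u32) {
  atomicAdd(f32, 1.0f);  // F32 + F32: nvvm_atomic_load_add_f32
  atomicAdd(s32, 1);     // integral + integral: AtomicRMW Add
  atomicMax(s32, 42);    // signed max: AtomicRMW Max
  atomicMax(u32, 42u);   // unsigned max: AtomicRMW UMax
  atomicMin(s32, 7);     // signed min: AtomicRMW Min
  atomicMin(u32, 7u);    // unsigned min: AtomicRMW UMin
}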
@@ -301,20 +293,20 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
llvm::Type* element_address_type = element_type->getPointerTo();
int atomic_size = (element_size < 32) ? 32 : element_size;
- llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size);
+ llvm::Type* atomic_type = b_.getIntNTy(atomic_size);
llvm::Type* atomic_address_type =
atomic_type->getPointerTo(output_address_type->getPointerAddressSpace());
// cas_old_output_address and cas_new_output_address point to the scratch
// memory where we store the old and new values for the repeated atomicCAS
// operations.
- llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca(
+ llvm::Value* cas_old_output_address = b_.CreateAlloca(
atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
- llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca(
+ llvm::Value* cas_new_output_address = b_.CreateAlloca(
atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
// Emit preparation code to the preheader.
- llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock();
+ llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
llvm::Value* atomic_memory_address;
// binop_output_address points to the scratch memory that stores the
@@ -325,77 +317,71 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
CHECK_EQ((element_size % sizeof(char)), 0);
llvm::Type* address_int_type =
module_->getDataLayout().getIntPtrType(output_address_type);
- atomic_memory_address =
- ir_builder_.CreatePtrToInt(output_address, address_int_type);
+ atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type);
llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
- llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask);
+ llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask);
mask = llvm::ConstantInt::get(address_int_type, -4);
- atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask);
+ atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask);
atomic_memory_address =
- ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type);
- binop_output_address = ir_builder_.CreateAdd(
- ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type),
- offset);
+ b_.CreateIntToPtr(atomic_memory_address, atomic_address_type);
+ binop_output_address = b_.CreateAdd(
+ b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset);
binop_output_address =
- ir_builder_.CreateIntToPtr(binop_output_address, element_address_type);
+ b_.CreateIntToPtr(binop_output_address, element_address_type);
} else {
atomic_memory_address =
- ir_builder_.CreateBitCast(output_address, atomic_address_type);
+ b_.CreateBitCast(output_address, atomic_address_type);
binop_output_address =
- ir_builder_.CreateBitCast(cas_new_output_address, element_address_type);
+ b_.CreateBitCast(cas_new_output_address, element_address_type);
}
// Use the value from the memory that atomicCAS operates on to initialize
// cas_old_output.
llvm::Value* cas_old_output =
- ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output");
- ir_builder_.CreateStore(cas_old_output, cas_old_output_address);
+ b_.CreateLoad(atomic_memory_address, "cas_old_output");
+ b_.CreateStore(cas_old_output, cas_old_output_address);
llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
- ir_builder_.GetInsertPoint(), "atomic_op_loop_exit");
- llvm::BasicBlock* loop_body_bb =
- llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body",
- ir_builder_.GetInsertBlock()->getParent());
- ir_builder_.SetInsertPoint(loop_body_bb);
+ b_.GetInsertPoint(), "atomic_op_loop_exit");
+ llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(
+ b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent());
+ b_.SetInsertPoint(loop_body_bb);
// Change preheader's successor from loop_exit_bb to loop_body_bb.
loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);
// Emit the body of the loop that repeatedly invokes atomicCAS.
//
// Use cas_old_output to initialize cas_new_output.
- cas_old_output =
- ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output");
- ir_builder_.CreateStore(cas_old_output, cas_new_output_address);
+ cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output");
+ b_.CreateStore(cas_old_output, cas_new_output_address);
// Emits code to calculate new_output = operation(old_output, source);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
computation, {binop_output_address, source_address},
binop_output_address));
llvm::Value* cas_new_output =
- ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output");
+ b_.CreateLoad(cas_new_output_address, "cas_new_output");
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
- llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg(
+ llvm::Value* ret_value = b_.CreateAtomicCmpXchg(
atomic_memory_address, cas_old_output, cas_new_output,
llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);
// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.
- ir_builder_.CreateStore(
- ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"),
- cas_old_output_address);
+ b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"),
+ cas_old_output_address);
// Extract the success bit returned from atomicCAS and generate a
// conditional branch on the success bit.
- ir_builder_.CreateCondBr(
- ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
- loop_body_bb);
+ b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
+ loop_body_bb);
// Set the insertion point to the exit basic block so that the caller of
// this method can continue emitting code to the right place.
- SetToFirstInsertPoint(loop_exit_bb, &ir_builder_);
+ SetToFirstInsertPoint(loop_exit_bb, &b_);
return Status::OK();
}
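
The control flow emitted above is the standard compare-and-swap retry loop. A minimal C++ sketch of the same pattern, assuming a 32-bit element so none of the sub-word masking applies (AtomicRmwViaCas and op are hypothetical names):

#include <atomic>

template <typename T, typename Op>
void AtomicRmwViaCas(std::atomic<T>* output, T source, Op op) {
  T old_value = output->load();  // cas_old_output, loaded in the preheader
  T new_value;
  do {
    new_value = op(old_value, source);  // the nested computation
    // On failure, compare_exchange_weak refreshes old_value with the value
    // observed in memory, mirroring the store back to cas_old_output_address.
  } while (!output->compare_exchange_weak(old_value, new_value));
}

For elements narrower than 32 bits, the function instead rounds output_address down to a 4-byte boundary (address & -4), remembers the byte offset (address & 3), and runs the same loop on the containing 32-bit word, pointing the nested computation at that offset within the scratch word.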
@@ -421,46 +407,49 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation(
Status IrEmitter::HandleSelect(HloInstruction* select) {
auto pred = select->operand(0);
- auto on_true = select->operand(1);
- auto on_false = select->operand(2);
TF_RET_CHECK(pred->shape().element_type() == PRED);
-
- if (ShapeUtil::IsTuple(select->shape())) {
- llvm_ir::EmitTupleSelect(GetIrArray(*select, *select),
- GetIrArray(*pred, *select),
- GetBasePointer(*on_true),
- GetBasePointer(*on_false), &ir_builder_, module_);
- return Status::OK();
- }
-
// We must not call the subclass `DefaultAction` method, lest its
// `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
// assume no handler has already been called.
return IrEmitter::DefaultAction(select);
}
+Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
+ auto pred = tuple_select->operand(0);
+ auto on_true = tuple_select->operand(1);
+ auto on_false = tuple_select->operand(2);
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+ TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
+ TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape()));
+ llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
+ GetIrArray(*pred, *tuple_select),
+ GetBasePointer(*on_true), GetBasePointer(*on_false),
+ &b_, module_);
+ return Status::OK();
+}
+
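
Rough sketch of what llvm_ir::EmitTupleSelect does given the scalar predicate checked above: a tuple value is an array of element base pointers, and the select chooses pointers wholesale rather than copying element data (TupleSelectSketch is a hypothetical name):

#include <cstdint>

void TupleSelectSketch(bool pred, void** result, void* const* on_true,
                       void* const* on_false, int64_t num_elements) {
  for (int64_t i = 0; i < num_elements; ++i) {
    result[i] = pred ? on_true[i] : on_false[i];
  }
}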
namespace {
-llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) {
- return ir_builder->CreateExtractValue(x, {0});
-}
-
-llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) {
- return ir_builder->CreateExtractValue(x, {1});
-}
-
-std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(
- llvm::Value* lhs_value, llvm::Value* rhs_value,
- llvm::IRBuilder<>* ir_builder) {
- llvm::Value* lhs_real = Real(lhs_value, ir_builder);
- llvm::Value* lhs_imag = Imag(lhs_value, ir_builder);
- llvm::Value* rhs_real = Real(rhs_value, ir_builder);
- llvm::Value* rhs_imag = Imag(rhs_value, ir_builder);
- llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real);
- llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag);
- llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2);
- llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag);
- llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real);
- llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2);
+llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) {
+ return b->CreateExtractValue(x, {0});
+}
+
+llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) {
+ return b->CreateExtractValue(x, {1});
+}
+
+std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
+ llvm::Value* rhs_value,
+ llvm::IRBuilder<>* b) {
+ llvm::Value* lhs_real = Real(lhs_value, b);
+ llvm::Value* lhs_imag = Imag(lhs_value, b);
+ llvm::Value* rhs_real = Real(rhs_value, b);
+ llvm::Value* rhs_imag = Imag(rhs_value, b);
+ llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real);
+ llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag);
+ llvm::Value* real_result = b->CreateFSub(real_result1, real_result2);
+ llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag);
+ llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real);
+ llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2);
return {real_result, imag_result};
}
} // namespace
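
These helpers expand the textbook complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i over {real, imag} extract/insert pairs. The same computation in plain C++:

#include <utility>

std::pair<float, float> MultiplyComplexSketch(float a, float b, float c,
                                              float d) {
  return {a * c - b * d,   // real part: ac - bd
          a * d + b * c};  // imag part: ad + bc
}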
@@ -476,25 +465,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const Shape& rhs_shape = rhs_instruction->shape();
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
- llvm::Type* index_type = ir_builder_.getInt64Ty();
+ llvm::Type* index_type = b_.getInt64Ty();
llvm_ir::IrArray::Index element_index(index_type);
if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
// If the operands are scalar, don't emit any loops.
llvm::Value* lhs_value =
- lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_);
+ lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
llvm::Value* rhs_value =
- rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_);
+ rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
llvm::Value* result;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
- auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_);
+ auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
- result = ir_builder_.CreateInsertValue(result, value.first, {0});
- result = ir_builder_.CreateInsertValue(result, value.second, {1});
+ result = b_.CreateInsertValue(result, value.first, {0});
+ result = b_.CreateInsertValue(result, value.second, {1});
} else {
- result = ir_builder_.CreateFMul(lhs_value, rhs_value);
+ result = b_.CreateFMul(lhs_value, rhs_value);
}
- target_array.EmitWriteArrayElement(/*index=*/element_index, result,
- &ir_builder_);
+ target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
return Status::OK();
}
@@ -521,11 +509,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// Create loop nests which loop through the LHS operand dimensions and the RHS
// operand dimensions. The reduction dimension of the LHS and RHS are handled
// in a separate innermost loop which performs the sum of products.
- llvm_ir::ForLoopNest loop_nest(IrName(dot), &ir_builder_);
- llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
- lhs_array, lhs_reduction_dimension, "lhs", &loop_nest);
- llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
- rhs_array, rhs_reduction_dimension, "rhs", &loop_nest);
+ llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
+ llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest(
+ lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
+ llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
+ rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
// Create the reduction loop which does the sum of products reduction.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
@@ -545,7 +533,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
accum_type, // The pointee type of the alloca instruction.
"accum_address", // The name of the alloca instruction.
- &ir_builder_);
+ &b_);
// Initialize the accumulator in the preheader to zero.
new llvm::StoreInst(
@@ -559,27 +547,25 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// updated_accum = accum + lhs_element * rhs_element
// *accum_address = updated_accum
TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
- ir_builder_.SetInsertPoint(
+ b_.SetInsertPoint(
&*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
- llvm::Value* lhs_element =
- lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
- llvm::Value* rhs_element =
- rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
- llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
+ llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
+ llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
+ llvm::Value* accum = b_.CreateLoad(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
- auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_);
- llvm::Value* accum_real = Real(accum, &ir_builder_);
- llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first);
- updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0});
- llvm::Value* accum_imag = Imag(accum, &ir_builder_);
- llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second);
- updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1});
+ auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
+ llvm::Value* accum_real = Real(accum, &b_);
+ llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first);
+ updated_accum = b_.CreateInsertValue(accum, real_sum, {0});
+ llvm::Value* accum_imag = Imag(accum, &b_);
+ llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second);
+ updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1});
} else {
- llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
- updated_accum = ir_builder_.CreateFAdd(accum, product);
+ llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element);
+ updated_accum = b_.CreateFAdd(accum, product);
}
- ir_builder_.CreateStore(updated_accum, accum_address);
+ b_.CreateStore(updated_accum, accum_address);
// After the reduction loop exits, store the accumulator into the target
// address. The index into the target address is the concatenation of the rhs
@@ -596,16 +582,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_index.push_back(rhs_index[dimension]);
}
}
- SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
target_array.EmitWriteArrayElement(
target_index,
- ir_builder_.CreateLoad(
- accum_address), // The value written to the target array.
- &ir_builder_);
+ b_.CreateLoad(accum_address), // The value written to the target array.
+ &b_);
// Set the IR builder insert point to the exit basic block of the outer most
// loop. This ensures later instructions are inserted after this loop nest.
- ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+ b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
return Status::OK();
}
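
Concretely, for an F32 [M,K] x [K,N] dot (reduction dimensions: lhs dim 1, rhs dim 0), the nest emitted above behaves like this scalar sketch with hypothetical fixed shapes:

#include <cstdint>

constexpr int64_t M = 2, K = 3, N = 4;

void DotSketch(const float (&lhs)[M][K], const float (&rhs)[K][N],
               float (&out)[M][N]) {
  for (int64_t m = 0; m < M; ++m) {      // lhs loop nest, skipping dim K
    for (int64_t n = 0; n < N; ++n) {    // rhs loop nest, skipping dim K
      float accum = 0.0f;                // accum_address, zeroed in preheader
      for (int64_t k = 0; k < K; ++k) {  // innermost sum-of-products loop
        accum += lhs[m][k] * rhs[k][n];
      }
      out[m][n] = accum;  // stored after the reduction loop exits
    }
  }
}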
@@ -647,11 +632,10 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
- ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), module_));
- ir_builder_.CreateStore(
- ir_builder_.CreateLoad(GetBasePointer(*init_value)),
- accumulator_addr);
+ b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)),
+ accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
@@ -659,12 +643,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// AddLoopsForShapeOnDimensions will return an Index where induction
// Value*s are placed for each dimension in dimensions, and all the rest
// are nullptrs.
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
+ llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
const llvm_ir::IrArray::Index reduced_dims_index =
loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
"reduction_dim");
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
// Build a full index for the input argument, using reduced_dims_index
// as the base. In reduced_dims_index only the reduction dimensions are
@@ -683,13 +667,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
// Apply the reduction function to the loaded value.
llvm::Value* input_address =
- GetIrArray(*arg, *reduce)
- .EmitArrayElementAddress(input_index, &ir_builder_);
+ GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*function, {accumulator_addr, input_address}, accumulator_addr));
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_addr);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
+ return b_.CreateLoad(accumulator_addr);
});
}
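
For a concrete picture, reducing a hypothetical [4,8,5] F32 input over dimensions {1} makes the element generator above act like the following for one output element [d0, d2] (the '+' stands in for the nested reduction computation):

#include <cstdint>

float ReduceOneElementSketch(const float (&arg)[4][8][5], float init_value,
                             int64_t d0, int64_t d2) {
  float accumulator = init_value;       // accumulator_addr, seeded with init
  for (int64_t d1 = 0; d1 < 8; ++d1) {  // inner loop nest over reduced dims
    accumulator = accumulator + arg[d0][d1][d2];  // nested computation
  }
  return accumulator;  // loaded once the inner loops exit
}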
@@ -702,8 +685,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
- GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
- &ir_builder_, GetNestedComputer());
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
+ GetNestedComputer());
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@@ -737,17 +720,16 @@ Status IrEmitter::HandleRng(HloInstruction* random) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : random->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*operand, *random)
- .EmitReadArrayElement(index, &ir_builder_);
+ return GetIrArray(*operand, *random).EmitReadArrayElement(index, &b_);
};
}
// Emits a single-threaded loop because the loop body generated by the element
// generator for Rng can't be parallelized (b/32333178).
return llvm_ir::LoopEmitter(
- GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
+ GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
GetNestedComputer())
.MakeElementGenerator(random, operand_to_generator),
- GetIrArray(*random, *random), &ir_builder_)
+ GetIrArray(*random, *random), &b_)
.EmitLoop(IrName(random));
}
@@ -774,34 +756,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
- const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
- tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
- // Prepares the dimension list we will use to emit the loop nest. Outermost
- // loops are added first. Add loops in major-to-minor order, and skip the
- // reduction dimension.
- std::vector<int64> dimensions;
- const Shape& shape = operand_array.GetShape();
- for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) {
- int64 dimension = LayoutUtil::Major(shape.layout(), i);
- if (dimension != reduction_dimension) {
- dimensions.push_back(dimension);
- }
- }
-
- // Create loop nest with one for-loop for each dimension of the
- // output.
- llvm_ir::IrArray::Index index =
- loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
- // Verify every dimension except the reduction dimension was set in the index.
- for (size_t dimension = 0; dimension < index.size(); ++dimension) {
- if (dimension == reduction_dimension) {
- DCHECK_EQ(nullptr, index[dimension]);
- } else {
- DCHECK_NE(nullptr, index[dimension]);
- }
- }
- return index;
+Status IrEmitter::HandleIota(HloInstruction*) {
+ // TODO(b/64798317): implement iota on GPU.
+ return Unimplemented("Iota is not implemented on GPU.");
}
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
@@ -810,16 +767,16 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), module_),
- "return_buffer", &ir_builder_);
+ "return_buffer", &b_);
std::vector<llvm::Value*> parameter_buffers;
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
- parameter_element->getType(), "parameter_buffer", &ir_builder_));
- ir_builder_.CreateStore(parameter_element, parameter_buffers.back());
+ parameter_element->getType(), "parameter_buffer", &b_));
+ b_.CreateStore(parameter_element, parameter_buffers.back());
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
return_buffer));
- return ir_builder_.CreateLoad(return_buffer);
+ return b_.CreateLoad(return_buffer);
}
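
The convention here is that nested computations pass everything through memory: each scalar parameter is spilled to an alloca and the result is read back from a return buffer. A small C++ sketch (NestedAdd stands in for the emitted nested function, which additionally receives a temp-buffer base):

static void NestedAdd(const float* p0, const float* p1, float* ret) {
  *ret = *p0 + *p1;  // hypothetical nested computation body
}

float ComputeNestedElementSketch(float p0, float p1) {
  float ret;               // return_buffer alloca
  float b0 = p0, b1 = p1;  // parameter_buffers, one store per element
  NestedAdd(&b0, &b1, &ret);
  return ret;              // final load of return_buffer
}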
} // namespace gpu