-rw-r--r--  tensorflow/compiler/xla/service/BUILD                          1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD                      1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc   25
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h     6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc            376
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h               7
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc      688
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.h       126
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                      1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  120
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h    40
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc            103
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h               6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc   264
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD                  9
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h   400
16 files changed, 1242 insertions(+), 931 deletions(-)
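
The new llvm_ir:ir_builder_mixin target is the heart of this change: it lets the emitters write Load(...), Store(...), FMul(...) and so on in place of b_->CreateLoad(...), b_->CreateStore(...), b_->CreateFMul(...), and the bulk of the diff below is the mechanical application of that rename. A minimal sketch of the CRTP forwarding pattern, assuming only that the derived class exposes builder() (the real ir_builder_mixin.h wraps far more of the llvm::IRBuilder<> surface than shown here):

#include <utility>

#include "llvm/IR/IRBuilder.h"

// Sketch only, not the actual header. Derived must provide
// `llvm::IRBuilder<>* builder()`.
template <typename Derived>
class IrBuilderMixin {
 protected:
  template <typename... Args>
  llvm::LoadInst* Load(Args&&... args) {
    return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
  }

  template <typename... Args>
  llvm::StoreInst* Store(Args&&... args) {
    return mixin_builder()->CreateStore(std::forward<Args>(args)...);
  }

  template <typename... Args>
  llvm::Value* FMul(Args&&... args) {
    return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
  }

 private:
  // CRTP: reach the derived emitter's builder without virtual dispatch.
  llvm::IRBuilder<>* mixin_builder() {
    return static_cast<Derived*>(this)->builder();
  }
};

Note that mixin_builder() hands back a mutable IRBuilder, so the forwarding helpers cannot be const; that is why the diff also strips const from the Emit* methods and the virtuals they override.
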
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f164a614f1..025f0b0195 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2689,6 +2689,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index f0adfc5d45..4cd192873f 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -278,6 +278,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index db54454707..c8312d80bd 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -30,15 +30,16 @@ limitations under the License.
namespace xla {
namespace cpu {
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- lhs = b_->CreateFPCast(lhs, b_->getFloatTy());
- rhs = b_->CreateFPCast(rhs, b_->getFloatTy());
+ lhs = FPCast(lhs, b_->getFloatTy());
+ rhs = FPCast(rhs, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@@ -58,21 +59,21 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, {lhs, rhs});
+ llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- value = b_->CreateFPCast(value, b_->getFloatTy());
+ value = FPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@@ -91,16 +92,16 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, value);
+ llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
if (hlo->opcode() == HloOpcode::kMap) {
return [this, hlo, &operand_to_generator](
const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 76833e765d..e3fba9306b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
IrEmitter* ir_emitter_;
};
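
The dropped const on these overrides follows directly from the mixin: helpers such as FPCast() and Call() reach the builder through a non-const builder() accessor, so any method that uses them cannot be const. A hypothetical member illustrating the constraint:

// Hypothetical member of a class deriving from IrBuilderMixin<Self>.
// If it were declared const, the call below would not compile: the
// mixin helpers are non-const because emitting instructions mutates
// the IRBuilder's state.
llvm::Value* CastToF32(llvm::Value* v) {  // note: no trailing const
  return FPCast(v, builder()->getFloatTy());  // -> CreateFPCast(...)
}
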
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 321c2e9896..903e73f606 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -170,9 +170,9 @@ IrEmitter::~IrEmitter() {}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] =
- b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
- IrShapeType(bitcast->shape())->getPointerTo(),
- AsStringRef(IrName(bitcast)));
+ BitCast(GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(),
+ AsStringRef(IrName(bitcast)));
return Status::OK();
}
@@ -439,22 +439,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer = b_.CreateCall(
- acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer =
+ Call(acquire_func,
+ {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
- /*SrcAlign=*/1, length_32);
+ MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
- /*SrcAlign=*/1, length_32);
+ MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
- b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer,
- shape_ptr, b_.getInt32(shape_length)});
+ Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
+ b_.getInt32(shape_length)});
return Status::OK();
}
@@ -518,8 +518,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
- b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
- accumulator_address);
+ Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
std::vector<int64> window_size;
@@ -536,22 +536,21 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm::Value* in_bounds_condition = nullptr;
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
- b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
// input_index[i] < bound as an *unsigned* comparison, since a negative
// value will wrap to a large positive value.
- llvm::Value* index_condition = b_.CreateICmpULT(
- input_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ llvm::Value* index_condition =
+ ICmpULT(input_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
if (in_bounds_condition == nullptr) {
in_bounds_condition = index_condition;
} else {
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
}
CHECK(in_bounds_condition != nullptr);
@@ -564,12 +563,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
- b_.CreateStore(result, accumulator_address);
+ *reduce_window->to_apply(), {Load(accumulator_address), input_value},
+ "reducer_function");
+ Store(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_address);
+ return Load(accumulator_address);
}
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
@@ -646,7 +645,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- return b_.CreateLoad(init_value_addr);
+ return Load(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
@@ -666,7 +665,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
@@ -684,15 +683,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size());
llvm::Value* in_bounds_condition = b_.getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
- source_index[i], b_.getInt64(window.dimensions(i).stride()));
- operand_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
- operand_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ llvm::Value* strided_index =
+ NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition =
+ ICmpULT(operand_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -702,7 +700,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -711,38 +709,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
[&](const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* operand_element = Load(operand_address);
llvm::Value* result = EmitThreadLocalCall(
*select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
- "select_function");
+ {Load(selected_value_address), operand_element}, "select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -753,8 +750,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
llvm::Value* source_value =
@@ -836,7 +833,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
lhs_llvm_type, "convolution_sum_address", &b_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
- b_.CreateStore(constant_zero, sum_address);
+ Store(constant_zero, sum_address);
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
@@ -863,11 +860,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::Value* kernel_index,
const WindowDimension& window_dim) {
llvm::Value* strided_index =
- b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = b_.CreateNSWMul(
- kernel_index, b_.getInt64(window_dim.window_dilation()));
- return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index),
- b_.getInt64(window_dim.padding_low()));
+ NSWMul(output_index, b_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index =
+ NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
+ return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+ b_.getInt64(window_dim.padding_low()));
};
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -884,9 +881,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
// Also need to check that the input coordinates are not in one of the
// holes created by base dilation.
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
- llvm::Value* remainder =
- b_.CreateSRem(input_index, b_.getInt64(base_dilation));
- return b_.CreateICmpEQ(remainder, b_.getInt64(0));
+ llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
+ return ICmpEQ(remainder, b_.getInt64(0));
};
llvm::Value* in_bounds_condition = b_.getInt1(true);
@@ -894,17 +890,17 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
llvm::Value* dim_not_in_hole =
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok);
+ llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = And(in_bounds_condition, dim_ok);
}
// Now we need to map the dilated base coordinates back to the actual
// data indices on the lhs.
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
- return b_.CreateSDiv(input_index, b_.getInt64(base_dilation));
+ return SDiv(input_index, b_.getInt64(base_dilation));
};
for (int i = 0; i < num_spatial_dims; ++i) {
input_spatial[i] =
@@ -929,8 +925,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
for (int i = 0; i < num_spatial_dims; ++i) {
kernel_index[dnums.kernel_spatial_dimensions(i)] =
window.dimensions(i).window_reversal()
- ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
+ ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
: kernel_spatial[i];
}
@@ -939,13 +935,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
llvm::Value* product =
- b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_),
- kernel_array.EmitReadArrayElement(kernel_index, &b_));
- llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product);
- b_.CreateStore(sum, sum_address);
+ FMul(input_array.EmitReadArrayElement(input_index, &b_),
+ kernel_array.EmitReadArrayElement(kernel_index, &b_));
+ llvm::Value* sum = FAdd(Load(sum_address), product);
+ Store(sum, sum_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(sum_address);
+ return Load(sum_address);
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
@@ -1071,34 +1067,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
- b_.CreateCall(
- conv_func,
- {
- GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type),
- b_.CreateBitCast(lhs_address, ir_ptr_type),
- b_.CreateBitCast(rhs_address, ir_ptr_type),
- b_.getInt64(input_batch),
- b_.getInt64(input_rows),
- b_.getInt64(input_cols),
- b_.getInt64(input_channels),
- b_.getInt64(kernel_rows),
- b_.getInt64(kernel_cols),
- b_.getInt64(kernel_channels),
- b_.getInt64(kernel_filters),
- b_.getInt64(output_rows),
- b_.getInt64(output_cols),
- b_.getInt64(row_stride),
- b_.getInt64(col_stride),
- b_.getInt64(padding_top),
- b_.getInt64(padding_bottom),
- b_.getInt64(padding_left),
- b_.getInt64(padding_right),
- b_.getInt64(lhs_row_dilation),
- b_.getInt64(lhs_col_dilation),
- b_.getInt64(rhs_row_dilation),
- b_.getInt64(rhs_col_dilation),
- });
+ Call(conv_func, {
+ GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
+ BitCast(lhs_address, ir_ptr_type),
+ BitCast(rhs_address, ir_ptr_type),
+ b_.getInt64(input_batch),
+ b_.getInt64(input_rows),
+ b_.getInt64(input_cols),
+ b_.getInt64(input_channels),
+ b_.getInt64(kernel_rows),
+ b_.getInt64(kernel_cols),
+ b_.getInt64(kernel_channels),
+ b_.getInt64(kernel_filters),
+ b_.getInt64(output_rows),
+ b_.getInt64(output_cols),
+ b_.getInt64(row_stride),
+ b_.getInt64(col_stride),
+ b_.getInt64(padding_top),
+ b_.getInt64(padding_bottom),
+ b_.getInt64(padding_left),
+ b_.getInt64(padding_right),
+ b_.getInt64(lhs_row_dilation),
+ b_.getInt64(lhs_col_dilation),
+ b_.getInt64(rhs_row_dilation),
+ b_.getInt64(rhs_col_dilation),
+ });
return Status::OK();
}
@@ -1158,15 +1152,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
const int fft_rank = fft_length.size();
- b_.CreateCall(
- fft_func,
- {GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
- b_.CreateBitCast(operand_address, int8_ptr_type),
- b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank),
- b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
- b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
- b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+ Call(fft_func,
+ {GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
+ b_.getInt32(fft_rank), b_.getInt64(input_batch),
+ b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
return Status::OK();
}
@@ -1205,8 +1198,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
- /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
+ MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
return Status::OK();
@@ -1465,19 +1458,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
accumulator_shard_type, "accumulator", &b_, 0));
}
- llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value));
+ llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
for (llvm::Value* accumulator_shard : accumulator) {
llvm::Value* initial_value;
auto shard_type = accumulator_shard->getType()->getPointerElementType();
if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
initial_value =
- b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa);
+ VectorSplat(vector_type->getNumElements(), init_value_ssa);
} else {
initial_value = init_value_ssa;
}
- b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment);
+ AlignedStore(initial_value, accumulator_shard, element_alignment);
}
llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
@@ -1499,24 +1492,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
}
CHECK(output_index.end() == it);
- llvm::Value* input_address = b_.CreateBitCast(
+ llvm::Value* input_address = BitCast(
arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
for (int i = 0; i < accumulator.size(); i++) {
auto input_address_typed =
- b_.CreateBitCast(input_address, accumulator[i]->getType());
+ BitCast(input_address, accumulator[i]->getType());
auto current_accumulator_value =
- b_.CreateAlignedLoad(accumulator[i], element_alignment);
- auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment);
+ AlignedLoad(accumulator[i], element_alignment);
+ auto addend = AlignedLoad(input_address_typed, element_alignment);
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
auto reduced_result =
reduction_generator(&b_, current_accumulator_value, addend);
- b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment);
+ AlignedStore(reduced_result, accumulator[i], element_alignment);
if (i != (accumulator.size() - 1)) {
- input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(),
- input_address_typed, 1);
+ input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
+ input_address_typed, 1);
}
}
@@ -1525,8 +1518,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
ShardedVector result_ssa;
result_ssa.reserve(accumulator.size());
for (auto accumulator_shard : accumulator) {
- result_ssa.push_back(
- b_.CreateAlignedLoad(accumulator_shard, element_alignment));
+ result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
}
return result_ssa;
}
@@ -1535,18 +1527,18 @@ void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
const int alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
- auto store_address_typed = b_.CreateBitCast(
- store_address,
- llvm::PointerType::getUnqual(value_to_store[i]->getType()));
+ auto store_address_typed =
+ BitCast(store_address,
+ llvm::PointerType::getUnqual(value_to_store[i]->getType()));
- auto store_instruction = b_.CreateAlignedStore(
- value_to_store[i], store_address_typed, alignment);
+ auto store_instruction =
+ AlignedStore(value_to_store[i], store_address_typed, alignment);
containing_array.AnnotateLoadStoreInstructionWithMetadata(
store_instruction);
if (i != (value_to_store.size() - 1)) {
- store_address = b_.CreateConstInBoundsGEP1_32(
- value_to_store[i]->getType(), store_address_typed, 1);
+ store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
+ store_address_typed, 1);
}
}
}
@@ -1711,8 +1703,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = b_.CreateLoad(init_value_addr);
- b_.CreateStore(load_init_value, accumulator_addr);
+ llvm::Value* load_init_value = Load(init_value_addr);
+ Store(load_init_value, accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to compute
// the actual target element. For this, we build a new loop nest to iterate
@@ -1745,12 +1737,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
// Apply the reduction function to the loaded value.
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ *reduce->to_apply(), {Load(accumulator_addr), input_element},
"reduce_function");
- b_.CreateStore(result, accumulator_addr);
+ Store(result, accumulator_addr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
@@ -1988,7 +1980,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
[this, pad](const llvm_ir::IrArray::Index& target_index) {
const HloInstruction* padding_value = pad->operand(1);
llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
- return b_.CreateLoad(padding_value_addr);
+ return Load(padding_value_addr);
}));
// Create a loop to iterate over the operand elements and update the output
@@ -2010,10 +2002,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
- llvm::Value* offset = b_.CreateMul(
- operand_index[i],
- b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
- llvm::Value* index = b_.CreateAdd(
+ llvm::Value* offset =
+ Mul(operand_index[i],
+ b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
+ llvm::Value* index = Add(
offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
output_index.push_back(index);
}
@@ -2124,10 +2116,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
- b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca =
- b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)});
- b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ InBoundsGEP(operands_alloca, {b_.getInt64(i)});
+ Store(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
@@ -2139,9 +2131,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
auto* output_address_arg =
- b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
- b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca});
+ Call(custom_call_ir_function, {output_address_arg, operands_alloca});
return Status::OK();
}
@@ -2200,15 +2192,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
compute_function_->function());
- b_.CreateBr(header_bb);
+ Br(header_bb);
b_.SetInsertPoint(header_bb);
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
- llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ llvm::Value* while_predicate = ICmpNE(
+ Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2217,7 +2208,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
- b_.CreateCondBr(while_predicate, body_bb, exit_bb);
+ CondBr(while_predicate, body_bb, exit_bb);
// Calls the body function from the body block.
b_.SetInsertPoint(body_bb);
@@ -2226,7 +2217,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
// Finishes with a branch back to the header.
- b_.CreateBr(header_bb);
+ Br(header_bb);
// Adds the exit block to the function and sets the insert point there.
compute_function_->function()->getBasicBlockList().push_back(exit_bb);
@@ -2273,7 +2264,6 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
output_min2maj.end());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
- llvm::Type* i8_type = b_.getInt8Ty();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
@@ -2296,9 +2286,9 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
llvm::Value* target_region_begin =
- b_.CreateBitCast(target_array.EmitArrayElementAddress(
- outer_dims_index, &b_, "target_region"),
- i8_ptr_type);
+ BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_,
+ "target_region"),
+ i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
@@ -2312,13 +2302,12 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
- llvm::Value* copy_source_address = b_.CreateBitCast(
+ llvm::Value* copy_source_address = BitCast(
source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"),
i8_ptr_type);
llvm::Value* copy_target_address =
- b_.CreateGEP(i8_type, target_region_begin,
- b_.getInt64(byte_offset_into_target_region));
+ GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
@@ -2350,15 +2339,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
- auto* load_instruction = b_.CreateAlignedLoad(
- b_.CreateBitCast(source, primitive_ptr_type), element_alignment);
+ auto* load_instruction =
+ AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
- auto* store_instruction = b_.CreateAlignedStore(
- load_instruction, b_.CreateBitCast(target, primitive_ptr_type),
- element_alignment);
+ auto* store_instruction =
+ AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
+ element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
- auto* memcpy_instruction = b_.CreateMemCpy(
+ auto* memcpy_instruction = MemCpy(
target, /*DstAlign=*/element_alignment, source,
/*SrcAlign=*/element_alignment, element_count * primitive_type_size);
@@ -2420,9 +2409,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
- llvm::LoadInst* pred_value = b_.CreateLoad(
- GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = b_.CreateICmpNE(
+ llvm::LoadInst* pred_value =
+ Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = ICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
@@ -2509,8 +2498,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon(
int64 prof_counter_idx = it->second;
string counter_name = IrName("prof_counter", hlo.name());
- return b_.CreateGEP(GetProfileCountersArgument(),
- b_.getInt64(prof_counter_idx), AsStringRef(counter_name));
+ return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
+ AsStringRef(counter_name));
}
void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
@@ -2664,8 +2653,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
+ llvm::LoadInst* param_address_untyped = Load(param_address_offset);
if (!ShapeUtil::IsOpaque(target_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
@@ -2693,8 +2681,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
}
return buf_it->second;
}();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
@@ -2702,7 +2689,7 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
- llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
+ llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
tempbuf_address_base->setMetadata(
@@ -2716,10 +2703,10 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
if (slice.offset() > 0) {
// Adjust the address to account for the slice offset.
tempbuf_address_untyped =
- b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
+ InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
- return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address_untyped,
+ IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitTempBufferPointer(
@@ -2805,8 +2792,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
- /*SrcAlign=*/1, source_size);
+ MemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
@@ -2860,7 +2847,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
CHECK(!parameter->getType()->isPointerTy());
llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
+ Store(parameter, parameter_addr);
parameter_addrs.push_back(parameter_addr);
}
@@ -2869,29 +2856,28 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
- return b_.CreateLoad(return_value_buffer);
+ return Load(return_value_buffer);
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
absl::string_view name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 99c080b3db..ec68710d3f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -55,7 +56,8 @@ namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
// Create a new LLVM IR emitter.
//
@@ -100,6 +102,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return &b_; }
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
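
Wiring a class into the mixin takes just the two additions above: derive from IrBuilderMixin<IrEmitter> and expose builder(). A hedged sketch of the effect inside any handler body (illustrative, not code from this diff):

// Hypothetical IrEmitter member, for illustration only. Both lines emit
// exactly what the old b_.Create* spelling did; the mixin merely forwards.
void ExampleStoreLoad(llvm::Value* value, llvm::Value* addr) {
  Store(value, addr);                  // == builder()->CreateStore(value, addr)
  llvm::Value* reloaded = Load(addr);  // == builder()->CreateLoad(addr)
  (void)reloaded;  // silence unused-variable warnings in this sketch
}
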
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 52faaab25c..61f6055fc9 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -204,7 +204,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
} // namespace
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
} else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
@@ -218,7 +218,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -230,14 +230,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateICmpNE(operand_value, llvm::ConstantInt::get(
- operand_value->getType(), 0)),
+ ICmpNE(operand_value,
+ llvm::ConstantInt::get(operand_value->getType(), 0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsIntegralType(to_type)) {
- return b_->CreateIntCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
- primitive_util::IsSignedIntegralType(from_type));
+ return IntCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_),
+ primitive_util::IsSignedIntegralType(from_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (to_type == BF16) {
@@ -253,14 +253,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
return EmitComposeComplex(
- op, b_->CreateSIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, SIToFP(operand_value, to_ir_component_type), nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return EmitComposeComplex(
- op, b_->CreateUIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, UIToFP(operand_value, to_ir_component_type), nullptr);
}
}
return Unimplemented("conversion from primitive type %s to %s",
@@ -276,8 +274,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
@@ -292,8 +290,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
if (is_signed) {
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto cmp = b_->CreateICmpSGE(operand_value, GetZero(type));
- return Select(cmp, operand_value, b_->CreateNeg(operand_value));
+ auto cmp = ICmpSGE(operand_value, GetZero(type));
+ return Select(cmp, operand_value, Neg(operand_value));
} else {
return operand_value;
}
@@ -309,23 +307,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
<< op->shape().element_type();
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto cmp = b_->CreateICmpEQ(operand_value, GetZero(type));
- auto ashr = b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
- return Select(cmp, GetZero(type), b_->CreateOr(ashr, 1));
+ auto cmp = ICmpEQ(operand_value, GetZero(type));
+ auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
+ return Select(cmp, GetZero(type), Or(ashr, 1));
}
case HloOpcode::kNegate:
- return b_->CreateNeg(operand_value);
+ return Neg(operand_value);
case HloOpcode::kNot: {
auto type = op->shape().element_type();
if (type == PRED) {
// It is not sufficient to just call CreateNot() here because a PRED
// is represented as an i8 and the truth value is stored only in the
// bottom bit.
- return b_->CreateZExt(
- b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())),
- llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
} else if (primitive_util::IsIntegralType(type)) {
- return b_->CreateNot(operand_value);
+ return Not(operand_value);
}
return Unimplemented("unary op Not is not defined for type '%d'", type);
}
@@ -336,7 +333,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -353,8 +350,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
return EmitComposeComplex(
op,
- b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType(
- to_component_type, module_)),
+ FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
if (from_type == BF16) {
@@ -370,22 +367,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateFCmpUNE(
- operand_value,
- llvm::ConstantFP::get(operand_value->getType(), 0.0)),
+ FCmpUNE(operand_value,
+ llvm::ConstantFP::get(operand_value->getType(), 0.0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- return b_->CreateFPCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
- return b_->CreateFPToSI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToSI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
- return b_->CreateFPToUI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToUI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
PrimitiveType_Name(from_type),
@@ -400,8 +396,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
@@ -444,8 +440,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(operand_value, zero);
- auto olt = b_->CreateFCmpOLT(operand_value, zero);
+ auto oeq = FCmpOEQ(operand_value, zero);
+ auto olt = FCmpOLT(operand_value, zero);
return Select(oeq, zero,
Select(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
@@ -457,12 +453,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
auto infinity = llvm::ConstantFP::getInfinity(type);
- auto not_infinite = b_->CreateFCmpONE(abs_value, infinity);
+ auto not_infinite = FCmpONE(abs_value, infinity);
return b_->CreateZExt(not_infinite,
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
- return b_->CreateFNeg(operand_value);
+ return FNeg(operand_value);
case HloOpcode::kReal:
return operand_value;
case HloOpcode::kImag:
@@ -474,7 +470,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType component_type =
primitive_util::IsComplexType(input_type)
@@ -486,12 +482,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
@@ -499,14 +494,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
- auto a_plus_one = b_->CreateFAdd(a, one);
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one),
- b_->CreateFMul(b, b));
+ auto a_plus_one = FAdd(a, one);
+ auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -520,11 +513,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
- return EmitComposeComplex(op,
- b_->CreateFPCast(EmitExtractReal(operand_value),
- to_ir_component_type),
- b_->CreateFPCast(EmitExtractImag(operand_value),
- to_ir_component_type));
+ return EmitComposeComplex(
+ op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
+ FPCast(EmitExtractImag(operand_value), to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
@@ -534,8 +525,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
- return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b),
- b_->CreateFMul(exp_a, sin_b));
+ return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
}
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
@@ -546,8 +536,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
- auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one);
- auto imag_result = b_->CreateFMul(exp_a, sin_b);
+ auto real_result = FSub(FMul(exp_a, cos_b), one);
+ auto imag_result = FMul(exp_a, sin_b);
return EmitComposeComplex(op, real_result, imag_result);
}
case HloOpcode::kCos: {
@@ -562,14 +552,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)),
- b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b)));
+ return EmitComposeComplex(op,
+ FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
+ FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
}
case HloOpcode::kSin: {
// sin(z) = .5i(e^(-iz) - e^(iz))
@@ -585,14 +574,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)),
- b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b)));
+ return EmitComposeComplex(op,
+ FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
+ FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
}
case HloOpcode::kTanh: {
/*
@@ -620,61 +608,51 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
- auto exp_neg_a =
- b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
- auto exp_2a_minus_exp_neg_2a = b_->CreateFSub(
- b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a));
- auto cos_b_sq = b_->CreateFMul(cos_b, cos_b);
- auto sin_b_sq = b_->CreateFMul(sin_b, sin_b);
- auto real_num =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
- b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
- auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b);
- auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a =
+ FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = FMul(cos_b, cos_b);
+ auto sin_b_sq = FMul(sin_b, sin_b);
+ auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = FMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
auto exp_a_plus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
- auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a);
+ FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
auto exp_a_minus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
- auto imag_num = b_->CreateFMul(
- cos_b_sin_b,
- b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
- auto denom =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
- b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
- return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom),
- b_->CreateFDiv(imag_num, denom));
+ FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = FMul(
+ cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
+ auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, FDiv(real_num, denom),
+ FDiv(imag_num, denom));
}
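The kTanh case expands tanh(a+bi) into one rational expression over a shared denominator. A scalar transcription of the emitted arithmetic, checked against std::tanh (standard library only):

  #include <cassert>
  #include <cmath>
  #include <complex>

  int main() {
    double a = 0.5, b = 1.2;
    double exp_a = std::exp(a), exp_neg_a = 1.0 / exp_a;
    double cos_b = std::cos(b), sin_b = std::sin(b);
    double exp_2a_minus_exp_neg_2a = exp_a * exp_a - exp_neg_a * exp_neg_a;
    double real_num = cos_b * cos_b * exp_2a_minus_exp_neg_2a +
                      sin_b * sin_b * exp_2a_minus_exp_neg_2a;
    double plus_sq = (exp_a + exp_neg_a) * (exp_a + exp_neg_a);
    double minus_sq = (exp_a - exp_neg_a) * (exp_a - exp_neg_a);
    double imag_num = cos_b * sin_b * (plus_sq - minus_sq);
    double denom = cos_b * cos_b * plus_sq + sin_b * sin_b * minus_sq;
    std::complex<double> expected = std::tanh(std::complex<double>(a, b));
    assert(std::abs(real_num / denom - expected.real()) < 1e-12);
    assert(std::abs(imag_num / denom - expected.imag()) < 1e-12);
  }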
case HloOpcode::kAbs: {
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
{sum_sq->getType()}, b_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero);
+ auto oeq = FCmpOEQ(cplx_abs, zero);
return Select(
oeq, EmitComposeComplex(op, zero, zero),
- EmitComposeComplex(
- op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
- b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs)));
+ EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
+ FDiv(EmitExtractImag(operand_value), cplx_abs)));
}
case HloOpcode::kNegate:
- return EmitComposeComplex(op,
- b_->CreateFNeg(EmitExtractReal(operand_value)),
- b_->CreateFNeg(EmitExtractImag(operand_value)));
+ return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
+ FNeg(EmitExtractImag(operand_value)));
case HloOpcode::kReal:
return EmitExtractReal(operand_value);
case HloOpcode::kImag:
@@ -686,8 +664,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType operand_type = op->operand(0)->shape().element_type();
if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
operand_type == PRED) {
@@ -702,21 +679,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
- return b_->CreateFAdd(lhs_value, rhs_value);
+ return FAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateFSub(lhs_value, rhs_value);
+ return FSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateFMul(lhs_value, rhs_value);
+ return FMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return b_->CreateFDiv(lhs_value, rhs_value);
+ return FDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
- return b_->CreateFRem(lhs_value, rhs_value);
+ return FRem(lhs_value, rhs_value);
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@@ -758,61 +734,47 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kAdd:
- return EmitComposeComplex(op,
- b_->CreateFAdd(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFAdd(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kSubtract:
- return EmitComposeComplex(op,
- b_->CreateFSub(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFSub(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kMultiply:
return EmitComposeComplex(
op,
- b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value))));
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
auto rhs_sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(rhs_value),
- EmitExtractImag(rhs_value)));
+ FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero);
- auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero);
- auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero);
+ auto oeq = FCmpOEQ(rhs_sum_sq, zero);
+ auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero);
+ auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero);
return Select(
oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
- EmitComposeComplex(
- op,
- b_->CreateFDiv(
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq),
- b_->CreateFDiv(
- b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq)));
+ EmitComposeComplex(op,
+ FDiv(FAdd(FMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq),
+ FDiv(FSub(FMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq)));
}
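This is the textbook conjugate-multiplication form of complex division, with an explicit Select so a zero denominator yields IEEE inf/nan components instead of trapping. A scalar check of the main branch (standard library only):

  #include <cassert>
  #include <cmath>
  #include <complex>

  int main() {
    double a = 3.0, b = -1.0, c = 0.5, d = 2.0;
    double rhs_sum_sq = c * c + d * d;  // |c+di|^2
    // ((ac + bd) + (bc - ad)i) / (c^2 + d^2), as emitted on the safe branch.
    std::complex<double> q((a * c + b * d) / rhs_sum_sq,
                           (b * c - a * d) / rhs_sum_sq);
    std::complex<double> expected =
        std::complex<double>(a, b) / std::complex<double>(c, d);
    assert(std::abs(q - expected) < 1e-12);
    // When rhs_sum_sq == 0, the Select above instead yields (a/0, b/0),
    // i.e. IEEE inf/nan components.
  }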
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
@@ -822,21 +784,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
- return b_->CreateAnd(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kNe:
- return b_->CreateOr(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kPower: {
// (a+bi)^(c+di) =
@@ -848,26 +808,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
auto b = EmitExtractImag(lhs_value);
auto c = EmitExtractReal(rhs_value);
auto d = EmitExtractImag(rhs_value);
- auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
- auto half_c = b_->CreateFMul(one_half, c);
+ auto half_c = FMul(one_half, c);
TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
EmitPow(component_type, aa_p_bb, half_c));
- auto neg_d = b_->CreateFNeg(d);
+ auto neg_d = FNeg(d);
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
- auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs);
+ auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
EmitExp(component_type, neg_d_arg_lhs));
- auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
+ auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
- auto half_d = b_->CreateFMul(one_half, d);
- auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs),
- b_->CreateFMul(half_d, ln_aa_p_bb));
+ auto half_d = FMul(one_half, d);
+ auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
- return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q),
- b_->CreateFMul(coeff, sin_q));
+ return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q));
}
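In polar terms the kPower case computes |z|^c * e^(-d*arg(z)) for the magnitude and q = c*arg(z) + (d/2)*ln(a^2+b^2) for the angle. A scalar check against std::pow on std::complex (standard library only; a > 0 keeps the test on the principal branch):

  #include <cassert>
  #include <cmath>
  #include <complex>

  int main() {
    double a = 1.1, b = 0.7, c = 0.3, d = -0.9;
    double aa_p_bb = a * a + b * b;
    double arg_lhs = std::atan2(b, a);
    double coeff = std::pow(aa_p_bb, 0.5 * c) * std::exp(-d * arg_lhs);
    double q = c * arg_lhs + 0.5 * d * std::log(aa_p_bb);
    std::complex<double> actual(coeff * std::cos(q), coeff * std::sin(q));
    std::complex<double> expected =
        std::pow(std::complex<double>(a, b), std::complex<double>(c, d));
    assert(std::abs(actual - expected) < 1e-12);
  }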
default:
return Unimplemented("binary complex op '%s'",
@@ -876,17 +834,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
- llvm::Value* x) const {
+ llvm::Value* x) {
if (prim_type != F32) {
// TODO(b/34339814): Implement inverse erf for F64.
return Unimplemented(
@@ -901,7 +859,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* p = getFloat(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), getFloat(coefficient));
}
return p;
};
@@ -921,25 +879,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::log, {b_->getFloatTy()});
- llvm::Value* w = b_->CreateFNeg(b_->CreateCall(
- logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x),
- b_->CreateFAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(
+ Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
llvm::Value* p_addr =
llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
+ FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
// Handle true BB.
SetToFirstInsertPoint(if_data.true_block, b_);
{
- llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f));
+ llvm::Value* lw = FSub(w, getFloat(2.5f));
tensorflow::gtl::ArraySlice<float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
llvm::Value* p = multiply_add(lq, lw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
// Handle false BB.
@@ -948,76 +905,73 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
- llvm::Value* gw =
- b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
tensorflow::gtl::ArraySlice<float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
llvm::Value* p = multiply_add(gq, gw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
SetToFirstInsertPoint(if_data.after_block, b_);
- llvm::Value* p = b_->CreateLoad(p_addr);
- return b_->CreateFMul(p, x);
+ llvm::Value* p = Load(p_addr);
+ return FMul(p, x);
}
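multiply_add above is Horner's rule: seed p with the leading coefficient, then fold in the rest with p = p*w + c. A minimal scalar version (hypothetical helper, standard library only):

  #include <cstdio>
  #include <vector>

  // Same scheme as multiply_add: start from the leading coefficient and
  // fold in the remaining ones with p = p * w + c (Horner's rule).
  float horner(const std::vector<float>& coefficients, float w) {
    float p = coefficients.front();
    for (size_t i = 1; i < coefficients.size(); ++i) {
      p = p * w + coefficients[i];
    }
    return p;
  }

  int main() {
    // 2w^2 + 3w + 4 at w = 1.5 evaluates to 13.
    std::printf("%f\n", horner({2.0f, 3.0f, 4.0f}, 1.5f));
  }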
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
+ llvm::Value* value) {
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
- return EmitErfInv(prim_type, b_->CreateFSub(one, value));
+ return EmitErfInv(prim_type, FSub(one, value));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
auto negative_half = llvm::ConstantFP::get(type, -0.5);
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
- TF_ASSIGN_OR_RETURN(auto for_large_x,
- EmitLog(prim_type, b_->CreateFAdd(x, one)));
+ TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
  // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
- auto for_small_x =
- b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x);
+ auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
+ auto x_is_small = FCmpOLT(
abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
return Select(x_is_small, for_small_x, for_large_x);
}
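The branch pays off in low precision: for small x, forming 1 + x first rounds away most of x's low bits, while x*(1 - x/2) keeps full relative accuracy. A float demonstration (standard library only):

  #include <cmath>
  #include <cstdio>

  int main() {
    float x = 1e-6f;
    float naive = std::log(1.0f + x);         // 1 + x rounds away low bits of x
    float series = (1.0f + (-0.5f) * x) * x;  // the |x| < 1e-4 branch above
    std::printf("naive=%.9g series=%.9g std::log1p=%.9g\n",
                naive, series, std::log1p(x));
  }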
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
@@ -1025,40 +979,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
- auto for_large_x = b_->CreateFSub(exp_x, one);
+ auto for_large_x = FSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
- auto x_squared = b_->CreateFAdd(x, x);
- auto x_squared_over_two = b_->CreateFMul(x_squared, half);
- auto for_small_x = b_->CreateFAdd(x, x_squared_over_two);
+ auto x_squared = FMul(x, x);  // x^2 for the x^2/2 term below
+ auto x_squared_over_two = FMul(x_squared, half);
+ auto for_small_x = FAdd(x, x_squared_over_two);
const auto kExponentIsSmallThreshold = 1e-5;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
- abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
+ auto x_is_small =
+ FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
return Select(x_is_small, for_small_x, for_large_x);
}
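Same pattern as log1p: exp(x) rounds to roughly 1 for tiny x, so subtracting 1 cancels catastrophically, while the two-term series x + x^2/2 stays accurate. A float demonstration (standard library only):

  #include <cmath>
  #include <cstdio>

  int main() {
    float x = 1e-7f;
    float naive = std::exp(x) - 1.0f;   // exp(x) rounds to ~1, cancels badly
    float series = x + 0.5f * (x * x);  // x + x^2/2, the small-|x| branch
    std::printf("naive=%.9g series=%.9g std::expm1=%.9g\n",
                naive, series, std::expm1(x));
  }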
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return Unimplemented("atan2");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return Unimplemented("tanh");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
- const HloInstruction* hlo, llvm::Value* x) const {
+ const HloInstruction* hlo, llvm::Value* x) {
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
@@ -1089,44 +1043,39 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
-llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) const {
+llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
}
-llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) const {
+llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
}
-llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) const {
+llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
auto* integer_type = llvm::cast<llvm::IntegerType>(type);
return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
integer_type->getBitWidth()));
}
-llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) const {
+llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
auto* integer_type = llvm::cast<llvm::IntegerType>(type);
return llvm::ConstantInt::get(
integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
}
-llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) const {
- return b_->CreateICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
-}
-
-llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(
- llvm::Value* lhs, llvm::Value* rhs) const {
- return b_->CreateAnd(b_->CreateICmpEQ(lhs, GetIntSMin(lhs->getType())),
- b_->CreateICmpEQ(rhs, GetMinusOne(rhs->getType())));
+llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
+ return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
}
-llvm::Value* ElementalIrEmitter::Select(llvm::Value* cond, llvm::Value* if_true,
- llvm::Value* if_false) const {
- return b_->CreateSelect(cond, if_true, if_false);
+llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
+ llvm::Value* rhs) {
+ return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
+ ICmpEQ(rhs, GetMinusOne(rhs->getType())));
}
llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
llvm::Value* rhs,
- bool is_signed) const {
+ bool is_signed) {
// Integer division overflow behavior:
//
// X / 0 == -1
@@ -1135,16 +1084,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
if (!is_signed) {
llvm::Value* udiv_is_unsafe = IsZero(rhs);
llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
- llvm::Value* safe_div = b_->CreateUDiv(lhs, safe_rhs);
+ llvm::Value* safe_div = UDiv(lhs, safe_rhs);
return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
}
llvm::Value* has_zero_divisor = IsZero(rhs);
llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
- llvm::Value* sdiv_is_unsafe =
- b_->CreateOr(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
- llvm::Value* safe_div = b_->CreateSDiv(lhs, safe_rhs);
+ llvm::Value* safe_div = SDiv(lhs, safe_rhs);
return Select(
has_zero_divisor, GetMinusOne(lhs->getType()),
@@ -1153,7 +1101,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
llvm::Value* rhs,
- bool is_signed) const {
+ bool is_signed) {
// Integer remainder overflow behavior:
//
// X % 0 == X
@@ -1162,16 +1110,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
if (!is_signed) {
llvm::Value* urem_is_unsafe = IsZero(rhs);
llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
- llvm::Value* safe_rem = b_->CreateURem(lhs, safe_rhs);
+ llvm::Value* safe_rem = URem(lhs, safe_rhs);
return Select(urem_is_unsafe, lhs, safe_rem);
}
llvm::Value* has_zero_divisor = IsZero(rhs);
llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
- llvm::Value* srem_is_unsafe =
- b_->CreateOr(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
- llvm::Value* safe_rem = b_->CreateSRem(lhs, safe_rhs);
+ llvm::Value* safe_rem = SRem(lhs, safe_rhs);
return Select(
has_zero_divisor, lhs,
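Both helpers follow one pattern: detect the unsafe divisor, substitute 1, divide, then select the documented result. A scalar C++ model of the signed divide path (the tail of its Select chain is truncated by the hunk above, so returning INT_MIN for the INT_MIN / -1 case is an assumption consistent with the guards):

  #include <cassert>
  #include <climits>

  // Scalar model of the emitted guards: pick a harmless divisor, divide,
  // then select the documented result, avoiding the UB both corner cases
  // have in plain C++.
  int safe_sdiv(int lhs, int rhs) {
    bool zero_divisor = (rhs == 0);
    bool int_min_overflow = (lhs == INT_MIN && rhs == -1);
    int safe_rhs = (zero_divisor || int_min_overflow) ? 1 : rhs;
    int safe_div = lhs / safe_rhs;
    return zero_divisor ? -1 : (int_min_overflow ? INT_MIN : safe_div);
  }

  int main() {
    assert(safe_sdiv(7, 0) == -1);           // X / 0 == -1
    assert(safe_sdiv(INT_MIN, -1) == INT_MIN);
    assert(safe_sdiv(9, 2) == 4);            // ordinary case unaffected
  }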
@@ -1180,15 +1127,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const {
+ bool is_signed) {
switch (op->opcode()) {
// TODO(jingyue): add the "nsw" attribute for signed types.
case HloOpcode::kAdd:
- return b_->CreateAdd(lhs_value, rhs_value);
+ return Add(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateSub(lhs_value, rhs_value);
+ return Sub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateMul(lhs_value, rhs_value);
+ return Mul(lhs_value, rhs_value);
case HloOpcode::kDivide:
return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
case HloOpcode::kRemainder:
@@ -1220,11 +1167,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
case HloOpcode::kMaximum:
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
- return b_->CreateAnd(lhs_value, rhs_value);
+ return And(lhs_value, rhs_value);
case HloOpcode::kOr:
- return b_->CreateOr(lhs_value, rhs_value);
+ return Or(lhs_value, rhs_value);
case HloOpcode::kXor:
- return b_->CreateXor(lhs_value, rhs_value);
+ return Xor(lhs_value, rhs_value);
// Shifting out bits >= the number of bits in the type being shifted
// produces a poison value in LLVM which is basically "deferred undefined
@@ -1233,15 +1180,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
// UB.
case HloOpcode::kShiftRightArithmetic:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateAShr(lhs_value, rhs_value),
+ AShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/true);
case HloOpcode::kShiftLeft:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateShl(lhs_value, rhs_value),
+ Shl(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
case HloOpcode::kShiftRightLogical:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateLShr(lhs_value, rhs_value),
+ LShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
default:
return Unimplemented("binary integer op '%s'",
@@ -1251,7 +1198,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
+ bool is_signed) {
return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
: llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
@@ -1260,7 +1207,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
+ bool is_signed) {
return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
: llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
@@ -1269,7 +1216,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const {
+ int64 operand_no) {
CHECK(hlo.IsElementwise())
<< "HLO " << hlo.ToString() << " is not elementwise.";
@@ -1310,7 +1257,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const {
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) {
TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean,
operand_to_generator.at(hlo->operand(0))(index));
TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma,
@@ -1328,17 +1275,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// Perform the division using the float type with the same number of bits
// as the raw value to avoid overflow.
if (raw_value_size_in_bits == 32) {
- elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
- elem_value = b_->CreateFDiv(
- elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
+ elem_value = UIToFP(elem_value, b_->getFloatTy());
+ elem_value = FDiv(elem_value,
+ llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
} else {
- elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
- elem_value = b_->CreateFDiv(
+ elem_value = UIToFP(elem_value, b_->getDoubleTy());
+ elem_value = FDiv(
elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
}
if (elem_ir_ty != elem_value->getType()) {
- elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
+ elem_value = FPTrunc(elem_value, elem_ir_ty);
}
}
@@ -1346,9 +1293,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
switch (hlo->random_distribution()) {
case RNG_UNIFORM: {
if (elem_ir_ty->isFloatingPointTy()) {
- return b_->CreateFAdd(
- b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value),
- a_or_mean);
+ return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean);
} else {
// To generate a uniform random value in [a, b) from a raw random sample
// in range [0, 2^N), we let range = b - a and return
@@ -1361,17 +1306,16 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// the same cost as if the whole warp were to re-sample. So an
// efficient re-sampling implementation on GPU would need to do
// nontrivial work to share entropy between threads in the warp.
- auto range = b_->CreateSub(b_or_sigma, a_or_mean);
- return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range));
+ auto range = Sub(b_or_sigma, a_or_mean);
+ return Add(a_or_mean, URem(elem_value, range));
}
}
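The two RNG_UNIFORM paths in scalar form: floats get the affine map a + (b-a)*u, integers get a + raw % range, accepting the slight modulo bias the comment above discusses (standard library only):

  #include <cassert>
  #include <cstdint>

  int main() {
    // Float path: FAdd(FMul(FSub(b, a), u), a) maps u in [0,1) onto [a, b).
    float a = 2.0f, b = 6.0f, u = 0.25f;
    assert((b - a) * u + a == 3.0f);
    // Integer path: a + raw % (b - a); slightly biased unless b - a
    // divides 2^N, the trade-off accepted above.
    uint32_t raw = 1234567u, lo = 10u, hi = 17u;
    uint32_t v = lo + raw % (hi - lo);
    assert(v >= lo && v < hi);
  }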
case RNG_NORMAL: {
TF_ASSIGN_OR_RETURN(
llvm::Value * r,
- EmitErfcInv(elem_prim_ty,
- b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
- elem_value)));
- return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean);
+ EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
+ elem_value)));
+ return FAdd(FMul(r, b_or_sigma), a_or_mean);
}
default:
return InvalidArgument(
@@ -1491,8 +1435,7 @@ std::array<llvm::Value*, 4> CalculateSampleValues(
// Precondition: the RNG instruction is not fused.
llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
VLOG(3) << "Using philox RNG algorithm";
CHECK(!hlo->IsFused());
// A random number generated by the per module random number generator.
@@ -1515,7 +1458,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Load the global state variable for the Philox RNG algorithm.
llvm::GlobalVariable* rng_state_ptr =
llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_);
- llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value");
+ llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value");
// Build and return the elemental IR generator to generate a random value for
// the element corresponding to the current thread.
@@ -1541,8 +1484,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// element within the sample.
llvm::Value* elems_per_sample_value =
llvm::ConstantInt::get(index_ty, elems_per_sample);
- llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value);
- llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value);
+ llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value);
+ llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value);
std::array<llvm::Value*, 4> counter_values = CalculateSampleValues(
sample_idx, hlo_random_value, global_random_number, rng_state, b_);
@@ -1550,18 +1493,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Store the four counter_values into the sample_address alloca so we can
// load the elem_offset'th one below.
for (int idx = 0; idx < 4; ++idx) {
- b_->CreateStore(counter_values[idx],
- b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx)));
+ Store(counter_values[idx],
+ InBoundsGEP(sample_address, b_->getInt32(idx)));
}
llvm::Type* int64_ty = b_->getInt64Ty();
CHECK(elems_per_sample == 2 || elems_per_sample == 4);
llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty;
// Retrieve the raw value for the current element from the current sample.
- llvm::Value* raw_elem_value = b_->CreateLoad(
- b_->CreateInBoundsGEP(
- b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()),
- elem_offset),
+ llvm::Value* raw_elem_value = Load(
+ InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()),
+ elem_offset),
"raw_elem_value");
return ConvertValueForDistribution(hlo, operand_to_generator, index,
@@ -1572,7 +1514,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1582,14 +1524,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
- return Select(b_->CreateTrunc(pred_value, b_->getInt1Ty()), on_true_value,
+ return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
on_false_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1615,7 +1557,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const {
+ const llvm_ir::IrArray::Index& target_index) {
const int64 concat_dim = hlo->dimensions(0);
auto source_index = target_index;
@@ -1637,9 +1579,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
}
llvm_ir::SetToFirstInsertPoint(exit_block, b_);
- llvm::PHINode* output = b_->CreatePHI(
- llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
- hlo->operands().size());
+ llvm::PHINode* output =
+ PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ hlo->operands().size());
auto prior_insert_point = b_->GetInsertPoint();
b_->SetInsertPoint(init_block);
@@ -1654,9 +1596,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
auto concat_dim_size =
llvm::ConstantInt::get(source_index[concat_dim]->getType(),
operand->shape().dimensions(concat_dim));
- b_->CreateCondBr(
- b_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
- true_block, false_block);
+ CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block,
+ false_block);
// Create the terminator of the true block before calling operand
// generators, because they require non-degenerate basic blocks.
@@ -1669,11 +1610,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
// Subtract the size of the concat dimension of the current operand
// from the source index.
b_->SetInsertPoint(false_block);
- source_index[concat_dim] =
- b_->CreateSub(source_index[concat_dim], concat_dim_size);
+ source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size);
}
- b_->CreateUnreachable();
+ Unreachable();
b_->SetInsertPoint(exit_block, prior_insert_point);
return output;
}
@@ -1681,7 +1621,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
@@ -1698,7 +1638,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
int64 largest_valid_start_index =
input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
CHECK_GE(largest_valid_start_index, 0);
@@ -1718,7 +1658,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
- input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]);
+ input_index[i] = Add(slice_start_index[i], index[i]);
}
return operand_to_generator.at(input_hlo)(input_index);
}
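The clamp keeps the whole slice in bounds rather than checking each element access. A scalar model of the formula in the comment (standard library only):

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
  int64_t clamp_start(int64_t start, int64_t operand_dim, int64_t output_dim) {
    return std::min(std::max<int64_t>(start, 0), operand_dim - output_dim);
  }

  int main() {
    // Taking a slice of width 4 from a dimension of size 10:
    assert(clamp_start(-3, 10, 4) == 0);  // clamped up to 0
    assert(clamp_start(8, 10, 4) == 6);   // pulled back so the slice fits
    assert(clamp_start(5, 10, 4) == 5);   // already valid
  }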
@@ -1726,7 +1666,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const Shape& operand_shape = hlo->operand(0)->shape();
const Shape& indices_shape = hlo->operand(1)->shape();
const Shape& output_shape = hlo->shape();
@@ -1775,7 +1715,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
- b_->CreateSExtOrTrunc(index_component, index_type);
+ SExtOrTrunc(index_component, index_type);
int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
@@ -1799,8 +1739,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
gather_dim_component_extended, is_signed),
is_signed);
- operand_index[operand_dim] = b_->CreateAdd(
- operand_index[operand_dim], gather_dim_component_extended_inbound);
+ operand_index[operand_dim] =
+ Add(operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
@@ -1824,7 +1764,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const HloInstruction* input_hlo = hlo->operand(0);
const HloInstruction* update_hlo = hlo->operand(1);
const HloInstruction* start_hlo = hlo->operand(2);
@@ -1847,7 +1787,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
llvm::Value* update_dim_size =
index_typed_const(update_hlo->shape().dimensions(i));
int64 largest_valid_start_index =
@@ -1863,14 +1803,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
- slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size);
-
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]),
- "slice_intersection");
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]),
- "slice_intersection");
+ slice_limit_index[i] = Add(slice_start_index[i], update_dim_size);
+
+ slice_intersection =
+ And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]),
+ "slice_intersection");
+ slice_intersection =
+ And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]),
+ "slice_intersection");
}
// Emit:
@@ -1887,26 +1827,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
- update_index[i] = b_->CreateSub(index[i], slice_start_index[i]);
+ update_index[i] = Sub(index[i], slice_start_index[i]);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));
- b_->CreateStore(true_value, ret_value_addr);
+ Store(true_value, ret_value_addr);
// Handle false BB (return data from 'input')
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
operand_to_generator.at(input_hlo)(index));
- b_->CreateStore(false_value, ret_value_addr);
+ Store(false_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const {
+ const llvm_ir::IrArray::Index& padded_index) {
auto index = padded_index;
llvm::Value* in_bounds = b_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
@@ -1914,26 +1854,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
- index[i] =
- b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low()));
- in_bounds = b_->CreateAnd(in_bounds,
- b_->CreateICmpSGE(index[i], index_typed_const(0)),
- "in_bounds");
- in_bounds = b_->CreateAnd(
+ index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds =
+ And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds");
+ in_bounds = And(
in_bounds,
- b_->CreateICmpEQ(
+ ICmpEQ(
index_typed_const(0),
- b_->CreateURem(index[i],
- index_typed_const(pad_dim.interior_padding() + 1))),
- "in_bounds");
- index[i] = b_->CreateSDiv(
- index[i], index_typed_const(pad_dim.interior_padding() + 1));
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpSLT(
- index[i],
- index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
+ index[i] =
+ SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1));
+ in_bounds =
+ And(in_bounds,
+ ICmpSLT(index[i],
+ index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ "in_bounds");
}
// if (in_bounds) {
@@ -1949,26 +1885,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
SetToFirstInsertPoint(if_data.true_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
- b_->CreateStore(operand_value, ret_value_addr);
+ Store(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(padding_value, ret_value_addr);
+ Store(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
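Per dimension the pad logic inverts the output index: shift off the low edge padding, reject interior-padding positions with a remainder test, divide, then bounds-check. A scalar model of that chain (hypothetical helper, standard library only):

  #include <cassert>
  #include <cstdint>

  // Mirrors the per-dimension Sub/URem/SDiv chain above.
  bool unpad_index(int64_t padded, int64_t edge_low, int64_t interior,
                   int64_t operand_dim, int64_t* operand_index) {
    int64_t i = padded - edge_low;
    if (i < 0) return false;                    // low edge padding
    if (i % (interior + 1) != 0) return false;  // interior padding slot
    i /= interior + 1;
    if (i >= operand_dim) return false;         // high edge padding
    *operand_index = i;
    return true;
  }

  int main() {
    int64_t idx;
    // edge_low = 2, interior = 1, operand dim 3: data lands at outputs 2, 4, 6.
    assert(unpad_index(4, 2, 1, 3, &idx) && idx == 1);
    assert(!unpad_index(3, 2, 1, 3, &idx));  // between data elements
    assert(!unpad_index(0, 2, 1, 3, &idx));  // low edge padding
  }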
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const {
+ const llvm_ir::IrArray::Index& dot_result_index) {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
@@ -1996,8 +1932,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
llvm::Value* accumulator_alloca =
llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
- b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
- accumulator_alloca);
+ Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
@@ -2019,42 +1954,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
}
rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
- llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca);
+ llvm::Value* current_accumulator = Load(accumulator_alloca);
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsComplexType(primitive_type)) {
- llvm::Value* product_real = b_->CreateFSub(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
- llvm::Value* product_imag = b_->CreateFAdd(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
- next_accumulator = b_->CreateInsertValue(
+ llvm::Value* product_real =
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
+ llvm::Value* product_imag =
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
+ next_accumulator = InsertValue(
current_accumulator,
- b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real),
- {0});
- next_accumulator = b_->CreateInsertValue(
+ FAdd(EmitExtractReal(current_accumulator), product_real), {0});
+ next_accumulator = InsertValue(
next_accumulator,
- b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag),
- {1});
+ FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
- next_accumulator = b_->CreateFAdd(current_accumulator,
- b_->CreateFMul(lhs_value, rhs_value));
+ next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
} else {
- next_accumulator =
- b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value));
+ next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
}
- b_->CreateStore(next_accumulator, accumulator_alloca);
+ Store(next_accumulator, accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
- return b_->CreateLoad(accumulator_alloca);
+ return Load(accumulator_alloca);
}
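The complex inner-loop accumulator update in scalar form, checked against std::complex arithmetic (standard library only):

  #include <cassert>
  #include <cmath>
  #include <complex>

  int main() {
    std::complex<double> lhs(1.0, 2.0), rhs(3.0, -1.0), acc(0.5, 0.5);
    // product_real / product_imag exactly as emitted above:
    double product_real = lhs.real() * rhs.real() - lhs.imag() * rhs.imag();
    double product_imag = lhs.real() * rhs.imag() + lhs.imag() * rhs.real();
    std::complex<double> next(acc.real() + product_real,
                              acc.imag() + product_imag);
    assert(std::abs(next - (acc + lhs * rhs)) < 1e-12);
  }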
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -2148,10 +2078,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
auto source_index = target_index;
for (int64 dim : hlo->dimensions()) {
- source_index[dim] = b_->CreateSub(
- llvm::ConstantInt::get(target_index[dim]->getType(),
- hlo->shape().dimensions(dim) - 1),
- target_index[dim]);
+ source_index[dim] =
+ Sub(llvm::ConstantInt::get(target_index[dim]->getType(),
+ hlo->shape().dimensions(dim) - 1),
+ target_index[dim]);
}
return operand_to_generator.at(operand)(source_index);
};
@@ -2235,23 +2165,23 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
}
}
-llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {0});
+llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
+ return ExtractValue(value, {0});
}
-llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {1});
+llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
+ return ExtractValue(value, {1});
}
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
llvm::Value* real,
- llvm::Value* imag) const {
+ llvm::Value* imag) {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto complex = b_->CreateInsertValue(
- llvm::ConstantAggregateZero::get(cplx_type), real, {0});
+ auto complex =
+ InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
if (imag != nullptr) {
- complex = b_->CreateInsertValue(complex, imag, {1});
+ complex = InsertValue(complex, imag, {1});
}
return complex;
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index c037b98929..d3e2acaabd 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -23,12 +23,13 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
-class ElementalIrEmitter {
+class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
public:
using HloToElementGeneratorMap =
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
@@ -40,115 +41,114 @@ class ElementalIrEmitter {
virtual ~ElementalIrEmitter() = default;
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
- llvm::Value* operand_value) const;
+ llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Returns a function to generate an element of the output of `hlo`, given a
// map of functions to generate elements of its operands.
virtual llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
- llvm::IRBuilder<>* b() const { return b_; }
- llvm::Module* module() const { return module_; }
+ llvm::IRBuilder<>* b() { return b_; }
+
+ // builder() is required by IrBuilderMixin, which forwards through it.
+ llvm::IRBuilder<>* builder() { return b_; }
+
+ llvm::Module* module() { return module_; }
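ir_builder_mixin.h itself is outside this excerpt, so the following is only a plausible sketch of the CRTP shape the new base class implies; the wrapper set and exact signatures are assumptions, inferred from the builder() accessor above and the call sites in the .cc diff:

  #include <utility>

  #include "llvm/IR/IRBuilder.h"

  // Plausible sketch only: CRTP wrappers that forward to the IRBuilder
  // returned by Derived::builder(), letting emitters write FAdd(a, b)
  // instead of b_->CreateFAdd(a, b).
  template <typename Derived>
  class IrBuilderMixin {
   protected:
    template <class... Args>
    llvm::Value* FAdd(Args&&... args) {
      return derived().builder()->CreateFAdd(std::forward<Args>(args)...);
    }
    template <class... Args>
    llvm::Value* FMul(Args&&... args) {
      return derived().builder()->CreateFMul(std::forward<Args>(args)...);
    }
    // ...one wrapper per IRBuilder Create* method the emitters use.

   private:
    Derived& derived() { return *static_cast<Derived*>(this); }
  };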
protected:
- virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- llvm::Value* IsZero(llvm::Value* v) const;
- llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs,
- llvm::Value* rhs) const;
- llvm::Value* GetZero(llvm::Type* type) const;
- llvm::Value* GetOne(llvm::Type* type) const;
- llvm::Value* GetIntSMin(llvm::Type* type) const;
- llvm::Value* GetMinusOne(llvm::Type* type) const;
- llvm::Value* Select(llvm::Value* cond, llvm::Value* if_true,
- llvm::Value* if_false) const;
+ llvm::Value* IsZero(llvm::Value* v);
+ llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* GetZero(llvm::Type* type);
+ llvm::Value* GetOne(llvm::Type* type);
+ llvm::Value* GetIntSMin(llvm::Type* type);
+ llvm::Value* GetMinusOne(llvm::Type* type);
llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
- bool is_signed) const;
+ bool is_signed);
llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
- bool is_signed) const;
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
- virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
- virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
- llvm::Value* x) const;
+ llvm::Value* x);
- virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;
- virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;
+ virtual llvm::Value* EmitExtractReal(llvm::Value* value);
+ virtual llvm::Value* EmitExtractImag(llvm::Value* value);
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
- llvm::Value* imag) const;
+ llvm::Value* imag);
// A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
// the target array index, computes the source array index of its
@@ -157,50 +157,50 @@ class ElementalIrEmitter {
// Precondition: `hlo` is an elementwise op.
llvm_ir::IrArray::Index ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const;
+ int64 operand_no);
// Identifier of the thread unique among all threads on the device
- virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); }
+ virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
StatusOr<llvm::Value*> EmitElementalSelect(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalClamp(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalConcatenate(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const;
+ const llvm_ir::IrArray::Index& target_index);
StatusOr<llvm::Value*> EmitElementalDynamicSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalGather(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalPad(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const;
+ const llvm_ir::IrArray::Index& padded_index);
StatusOr<llvm::Value*> EmitElementalDot(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const;
+ const llvm_ir::IrArray::Index& dot_result_index);
llvm::IRBuilder<>* const b_;
@@ -215,13 +215,13 @@ class ElementalIrEmitter {
// random number generation algorithm.
llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
// Converts the raw value generated by a random number generation algorithm
// to the distribution requested by the RNG HloInstruction.
StatusOr<llvm::Value*> ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const;
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value);
};
} // namespace xla
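
The const-removal running through every Emit* signature above is the heart of this change: the new IrBuilderMixin (added at the bottom of this diff) exposes its forwarding wrappers as non-const member functions, so any emitter method that calls them, and any override that must match, can no longer be const. Below is a minimal sketch of the pattern, assuming only what the diff itself shows (CRTP over the derived emitter plus a builder() accessor like the one gpu::IrEmitter gains later in this diff); the real, roughly 400-line header wraps many more IRBuilder methods:

    template <typename Derived>
    class IrBuilderMixin {
     protected:
      // Each wrapper just drops the "b_->Create" prefix at the call site.
      llvm::Value* FPCast(llvm::Value* v, llvm::Type* dest_ty) {
        return mixin_builder()->CreateFPCast(v, dest_ty);
      }
      llvm::Value* Load(llvm::Value* ptr, const llvm::Twine& name = "") {
        return mixin_builder()->CreateLoad(ptr, name);
      }
      llvm::Value* Store(llvm::Value* value, llvm::Value* ptr) {
        return mixin_builder()->CreateStore(value, ptr);
      }
      llvm::Value* NSWAdd(llvm::Value* lhs, llvm::Value* rhs) {
        return mixin_builder()->CreateNSWAdd(lhs, rhs);
      }

     private:
      llvm::IRBuilder<>* mixin_builder() {
        // CRTP contract: the derived emitter provides a builder() accessor.
        return static_cast<Derived*>(this)->builder();
      }
    };
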
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 87b799e78e..d6e9436348 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -177,6 +177,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index afcf9fa2ea..57a3a43a6f 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -77,7 +77,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
// The libdevice math functions differentiate between "double" and "float" by
// appending an 'f' to the function's name. libdevice doesn't have f16 math
// functions, so we convert the operands to f32 before calling the function
@@ -94,7 +94,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
converted_operands[i] =
- b_->CreateFPCast(converted_operands[i], b_->getFloatTy());
+ FPCast(converted_operands[i], b_->getFloatTy());
converted_input_types[i] = F32;
}
}
@@ -113,7 +113,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
@@ -122,7 +122,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
// llvm intrinsics differentiate between half/float/double functions via
// the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
@@ -147,7 +147,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ PrimitiveType output_type) {
  // Binary math functions are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
@@ -163,8 +163,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
@@ -183,8 +182,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
CHECK_EQ(op->opcode(), HloOpcode::kPower);
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
@@ -218,7 +216,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
// TODO(jlebar): Does this happen with fastmath disabled? If not, should
// we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
- return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
+ return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString();
@@ -227,55 +225,56 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+ PrimitiveType prim_type, llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
// Emit a fast approximation of tanh instead of calling __nv_tanh.
// __nv_tanh is particularly bad because it contains branches, thus
// preventing LLVM's load-store vectorizer from working its magic across a
@@ -285,9 +284,9 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* input = FPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
+ return FPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
@@ -295,7 +294,7 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const {
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) {
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
@@ -315,29 +314,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
callee->addFnAttr(attribute);
}
- return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+ return Call(callee, llvm_ir::AsArrayRef(operands));
}
-llvm::Value* GpuElementalIrEmitter::EmitThreadId() const {
- llvm::Value* block_id = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "block.id");
- llvm::Value* thread_id_in_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
- llvm::Value* threads_per_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
- return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block),
- thread_id_in_block);
+llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
+ llvm::Value* block_id =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ llvm::Value* thread_id_in_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ llvm::Value* threads_per_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kMap:
return [=, &operand_to_generator](
@@ -383,7 +381,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(init_value, accum_ptr);
+ Store(init_value, accum_ptr);
}
llvm::Type* index_type = index.GetType();
@@ -405,22 +403,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
IrArray::Index input_index(index_type, index.size());
llvm::Value* in_bounds = b_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* stridden_index = b_->CreateNSWMul(
+ llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
- input_index[i] = b_->CreateNSWSub(
- b_->CreateNSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ input_index[i] =
+ NSWSub(NSWAdd(stridden_index, window_index[i]),
+ index_typed_const(window.dimensions(i).padding_low()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
// comparison is equivalent to the unsigned comparison
// input_index[i] < bound, as a negative value wraps to a large
// positive value.
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpULT(
- input_index[i],
- index_typed_const(operand->shape().dimensions(i))));
+ in_bounds =
+ And(in_bounds,
+ ICmpULT(input_index[i],
+ index_typed_const(operand->shape().dimensions(i))));
}
llvm_ir::LlvmIfData if_data =
@@ -432,12 +429,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
operand_to_generator.at(operand)(input_index));
TF_ASSIGN_OR_RETURN(
llvm::Value * accum_value,
- compute_nested_(*hlo->to_apply(),
- {b_->CreateLoad(accum_ptr), input_value}));
- b_->CreateStore(accum_value, accum_ptr);
+ compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
+ Store(accum_value, accum_ptr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
- return b_->CreateLoad(accum_ptr);
+ return Load(accum_ptr);
};
case HloOpcode::kReduce:
// TODO(b/112040122): This should be supported.
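
Two notes on the hunks above. First, the libdevice plumbing: as the comment in EmitLibdeviceMathCall says, libdevice ships only f32/f64 variants, distinguished by an 'f' suffix, so F16 operands are FPCast up to f32 and the result is cast back to half. A hypothetical helper (not in the XLA sources; PrimitiveType, F16/F32/F64, and LOG are assumed from the surrounding codebase) showing just the name-munging rule:

    #include <string>

    std::string MungeLibdeviceName(const std::string& callee_name,
                                   PrimitiveType type) {
      switch (type) {
        case F64:
          return callee_name;        // e.g. "__nv_atan2"
        case F16:                    // computed in f32, with casts around the call
        case F32:
          return callee_name + "f";  // e.g. "__nv_atan2f"
        default:
          LOG(FATAL) << "libdevice math call on non-floating-point type";
      }
    }

Second, the reduce-window bounds check folds 0 ≤ input_index[i] < bound into a single ICmpULT because a negative index, reinterpreted as unsigned, wraps to a value far above any dimension bound; that is why one And per dimension suffices.
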
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 84454d31bb..91942785d2 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -48,50 +48,50 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
- StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const override;
+ StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value) override;
StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
- llvm::Value* EmitThreadId() const override;
+ llvm::Value* EmitThreadId() override;
private:
// Emits IR for op, which must have opcode kPower.
StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Emits IR to call a device function named "callee_name" on the given
// operand. Returns the IR value that represents the return value.
@@ -100,7 +100,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_type,
PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const;
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes);
// Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
@@ -109,7 +109,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
// Emits IR to call a libdevice function of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
@@ -118,7 +118,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
@@ -126,7 +126,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 4cbb6d75a8..a620cebe04 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -156,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation(
std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
arguments.push_back(output);
arguments.push_back(bindings_.GetTempBufferBase());
- b_.CreateCall(emitted_function, arguments);
+ Call(emitted_function, arguments);
return Status::OK();
}
@@ -178,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
computation.root_instruction()->shape().element_type();
bool is_atomic_integral = element_type == S32 || element_type == U32 ||
element_type == S64 || element_type == U64;
- llvm::Value* source = b_.CreateLoad(source_address, "source");
+ llvm::Value* source = Load(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
if (element_type == F32) {
@@ -190,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
}
if (is_atomic_integral) {
// integral + integral
- b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
}
@@ -202,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -212,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -292,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// cas_old_output_address and cas_new_output_address point to the scratch
// memory where we store the old and new values for the repeated atomicCAS
// operations.
- llvm::Value* cas_old_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
- llvm::Value* cas_new_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
+ llvm::Value* cas_old_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
+ llvm::Value* cas_new_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
// Emit preparation code to the preheader.
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
@@ -309,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
CHECK_EQ((element_size % sizeof(char)), 0);
llvm::Type* address_int_type =
module_->getDataLayout().getIntPtrType(output_address_type);
- atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type);
+ atomic_memory_address = PtrToInt(output_address, address_int_type);
llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
- llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask);
+ llvm::Value* offset = And(atomic_memory_address, mask);
mask = llvm::ConstantInt::get(address_int_type, -4);
- atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask);
+ atomic_memory_address = And(atomic_memory_address, mask);
atomic_memory_address =
- b_.CreateIntToPtr(atomic_memory_address, atomic_address_type);
- binop_output_address = b_.CreateAdd(
- b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset);
+ IntToPtr(atomic_memory_address, atomic_address_type);
binop_output_address =
- b_.CreateIntToPtr(binop_output_address, element_address_type);
+ Add(PtrToInt(cas_new_output_address, address_int_type), offset);
+ binop_output_address = IntToPtr(binop_output_address, element_address_type);
} else {
- atomic_memory_address =
- b_.CreateBitCast(output_address, atomic_address_type);
+ atomic_memory_address = BitCast(output_address, atomic_address_type);
binop_output_address =
- b_.CreateBitCast(cas_new_output_address, element_address_type);
+ BitCast(cas_new_output_address, element_address_type);
}
// Use the value from the memory that atomicCAS operates on to initialize
// cas_old_output.
- llvm::Value* cas_old_output =
- b_.CreateLoad(atomic_memory_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_old_output_address);
+ llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
+ Store(cas_old_output, cas_old_output_address);
llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
b_.GetInsertPoint(), "atomic_op_loop_exit");
@@ -344,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// Emit the body of the loop that repeatedly invokes atomicCAS.
//
// Use cas_old_output to initialize cas_new_output.
- cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_new_output_address);
+ cas_old_output = Load(cas_old_output_address, "cas_old_output");
+ Store(cas_old_output, cas_new_output_address);
// Emits code to calculate new_output = operation(old_output, source);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
computation, {binop_output_address, source_address},
binop_output_address));
- llvm::Value* cas_new_output =
- b_.CreateLoad(cas_new_output_address, "cas_new_output");
+ llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
- llvm::Value* ret_value = b_.CreateAtomicCmpXchg(
- atomic_memory_address, cas_old_output, cas_new_output,
- llvm::AtomicOrdering::SequentiallyConsistent,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ llvm::Value* ret_value =
+ AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ llvm::AtomicOrdering::SequentiallyConsistent);
// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"),
- cas_old_output_address);
+ Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
// Extract the success bit returned from atomicCAS and generate a
// conditional branch on the success bit.
- b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
- loop_body_bb);
+ CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);
// Set the insertion point to the exit basic block so that the caller of
// this method can continue emitting code to the right place.
@@ -472,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
- result = b_.CreateInsertValue(result, value.first, {0});
- result = b_.CreateInsertValue(result, value.second, {1});
+ result = InsertValue(result, value.first, {0});
+ result = InsertValue(result, value.second, {1});
} else {
- result = b_.CreateFMul(lhs_value, rhs_value);
+ result = FMul(lhs_value, rhs_value);
}
target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
return Status::OK();
@@ -559,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
&*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
- llvm::Value* accum = b_.CreateLoad(accum_address);
+ llvm::Value* accum = Load(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
llvm::Value* accum_real = Real(accum, &b_);
- llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first);
- updated_accum = b_.CreateInsertValue(accum, real_sum, {0});
+ llvm::Value* real_sum = FAdd(accum_real, value.first);
+ updated_accum = InsertValue(accum, real_sum, {0});
llvm::Value* accum_imag = Imag(accum, &b_);
- llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second);
- updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1});
+ llvm::Value* imag_sum = FAdd(accum_imag, value.second);
+ updated_accum = InsertValue(updated_accum, imag_sum, {1});
} else {
- llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element);
- updated_accum = b_.CreateFAdd(accum, product);
+ llvm::Value* product = FMul(lhs_element, rhs_element);
+ updated_accum = FAdd(accum, product);
}
- b_.CreateStore(updated_accum, accum_address);
+ Store(updated_accum, accum_address);
// After the reduction loop exits, store the accumulator into the target
// address. The index into the target address is the concatenation of the rhs
@@ -595,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
target_array.EmitWriteArrayElement(
target_index,
- b_.CreateLoad(accum_address), // The value written to the target array.
+ Load(accum_address), // The value written to the target array.
&b_);
// Set the IR builder insert point to the exit basic block of the outer most
@@ -646,10 +640,9 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
- b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ Alloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), module_));
- b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)),
- accumulator_addr);
+ Store(Load(GetBasePointer(*init_value)), accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
@@ -686,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
*function, {accumulator_addr, input_address}, accumulator_addr));
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
});
}
@@ -769,11 +762,11 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", &b_));
- b_.CreateStore(parameter_element, parameter_buffers.back());
+ Store(parameter_element, parameter_buffers.back());
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
return_buffer));
- return b_.CreateLoad(return_buffer);
+ return Load(return_buffer);
}
} // namespace gpu
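
The rewritten atomic path above is the classic compare-and-swap retry loop: cas_old_output holds the value we expect to find, cas_new_output the value the nested computation produced, and the success bit extracted from the atomicCmpXchg result decides between exiting and retrying. A host-side sketch with std::atomic, assuming a natively CAS-able element width (the pointer-masking code above handles the sub-32-bit case by operating on the containing aligned word):

    #include <atomic>

    // `op` stands in for the nested computation invoked through
    // EmitCallToNestedComputation.
    template <typename T, typename BinOp>
    void AtomicOpViaCas(std::atomic<T>* address, T source, BinOp op) {
      T old_val = address->load();      // initialize cas_old_output
      T new_val = op(old_val, source);  // compute cas_new_output
      // compare_exchange_strong defaults to seq_cst, matching the
      // SequentiallyConsistent orderings in the emitted IR.
      while (!address->compare_exchange_strong(old_val, new_val)) {
        // On failure old_val is updated to the value actually observed in
        // memory (element 0 of the atomicCmpXchg result), so just recompute.
        new_val = op(old_val, source);
      }
    }
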
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 76e069fc41..e096a07704 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -64,7 +65,8 @@ namespace gpu {
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See comments on that class.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
IrEmitter(const IrEmitter&) = delete;
IrEmitter& operator=(const IrEmitter&) = delete;
@@ -99,6 +101,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
+ llvm::IRBuilder<>* builder() { return &b_; }
+
protected:
// Constructs an IrEmitter with the given IrEmitter context.
// ir_emitter_context is owned by the caller and should outlive the IrEmitter
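
Attaching IrBuilderMixin<IrEmitter> here, together with the new public builder() accessor, is what makes every unqualified Call/Load/Store/AtomicRMW spelling in gpu/ir_emitter.cc above resolve (see the sketch after the elemental_ir_emitter.h hunks). Compared with free helper functions, the mixin spares each call site an explicit builder argument; the cost is one extra base class and the CRTP requirement that the derived class expose builder().
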
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 4d98955c58..c0c8ae181a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -729,7 +729,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- b_.CreateStore(extra_output_ir_value, extra_output_address);
+ Store(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
@@ -810,17 +810,17 @@ Status IrEmitterUnnested::EmitReductionToScalar(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
@@ -832,15 +832,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)),
+ tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
@@ -849,11 +848,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
IrArray::Index input_index(
/*linear=*/x, input_shape, &b_);
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -864,14 +863,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileSize),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileSize),
+ NSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
- b_.getInt1(all_threads_in_bounds));
+ Or(ICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
@@ -892,20 +891,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -920,10 +917,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id =
- b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
@@ -1043,12 +1039,12 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1059,8 +1055,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
@@ -1072,34 +1068,32 @@ Status IrEmitterUnnested::EmitColumnReduction(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* y = b_.CreateNSWAdd(
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* y =
+ NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
+ tile_element_loop->GetIndVarValue());
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
- &b_);
+ ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1126,7 +1120,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
@@ -1141,20 +1135,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
- llvm::Value* y_end = b_.CreateNSWAdd(
- index_typed_constant(kTileHeight),
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
+ llvm::Value* y_end =
+ NSWAdd(index_typed_constant(kTileHeight),
+ NSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileWidth),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileWidth),
+ NSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
llvm::Value* tile_in_y_bounds =
- b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
- b_.getInt1(height % kTileHeight == 0));
+ Or(ICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
llvm::Value* tile_in_x_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
- b_.getInt1(width % kTileWidth == 0));
+ Or(ICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
@@ -1188,9 +1182,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
@@ -1379,11 +1373,11 @@ Status IrEmitterUnnested::EmitRowReduction(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1392,22 +1386,20 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = ZExtOrTrunc(x_tile, index_ty);
llvm::Value* warp_id =
- b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
+ UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id =
- b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = b_.CreateNSWAdd(
+ llvm::Value* last_x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- index_typed_constant(x_tile_size - 1),
- b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(index_typed_constant(x_tile_size - 1),
+ NSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&b_,
@@ -1419,9 +1411,8 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = b_.CreateNSWAdd(
- z_indvar,
- b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
+ llvm::Value* z =
+ NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
/*start=*/index_typed_constant(0),
@@ -1429,22 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = b_.CreateNSWAdd(
+ llvm::Value* x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- x_indvar, b_.CreateNSWMul(
- warp_id, llvm::ConstantInt::get(
- index_ty, x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(x_indvar,
+ NSWMul(warp_id, llvm::ConstantInt::get(
+ index_ty, x_tile_size)))));
// Unless we know the x-tile is entirely in bounds, we have to
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)),
- "x_in_bounds", &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
// Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&b_);
@@ -1452,7 +1441,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1483,7 +1472,7 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1503,8 +1492,8 @@ Status IrEmitterUnnested::EmitRowReduction(
};
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- b_.CreateICmpULT(last_x, index_typed_constant(width)));
+ Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ ICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1532,20 +1521,18 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1560,8 +1547,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
@@ -1845,7 +1831,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
&b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
@@ -1866,15 +1852,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
+ llvm::Value* strided_index = NSWMul(
source_index[i], index_typed_constant(window.dimensions(i).stride()));
- operand_index[i] = b_.CreateNSWSub(
- b_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_constant(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
+ operand_index[i] =
+ NSWSub(NSWAdd(strided_index, window_index[i]),
+ index_typed_constant(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ICmpULT(
operand_index[i],
index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -1884,7 +1870,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -1892,16 +1878,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to
// potentially update the selected value and index with the currently
@@ -1917,11 +1903,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = b_.CreateLoad(select_return_buffer);
+ llvm::Value* result = Load(select_return_buffer);
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
@@ -1930,7 +1916,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -1942,8 +1928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
@@ -2367,8 +2353,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
*slice.allocation())));
CHECK_NE(loc, nullptr);
} else {
- loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {b_.getInt64(slice.offset())});
+ loc = InBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
}
// If gte_index is nonempty, we have to dereference `loc` to get to the
@@ -2376,8 +2362,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::Type* int8_double_pointer =
llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = b_.CreateBitCast(loc, int8_double_pointer);
- loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
+ loc = BitCast(loc, int8_double_pointer);
+ loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
@@ -3154,9 +3140,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
const IrArray::Index output_tile_origin = [&] {
IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] =
- b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
- "tile_origin." + std::to_string(i));
+ index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
}
return index;
}();
@@ -3169,12 +3154,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
          // Only the last row or column may not have full size.
- output_tile_bounds[i] = b_.CreateSelect(
- b_.CreateICmpEQ(output_tile_index[i],
- index_typed_constant(output_dims_in_tiles[i] - 1)),
- index_typed_constant(reduced_output_dims[i] -
- (output_dims_in_tiles[i] - 1) * kTileSize),
- index_typed_constant(kTileSize), "kTileSize");
+ output_tile_bounds[i] =
+ Select(ICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
}
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
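
The Select above picks between a full tile and the remainder for the last
tile in each tiled dimension. Worked example with a reduced output dim of 100
and the same assumed kTileSize = 32:

    // ceil(100 / 32) = 4 tiles: tiles 0..2 get bound 32, the last gets 4.
    int64_t dim = 100, tile = 32;
    int64_t num_tiles = (dim + tile - 1) / tile;        // 4
    int64_t last_bound = dim - (num_tiles - 1) * tile;  // 100 - 96 = 4
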
@@ -3192,7 +3177,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
- index[dim] = b_.CreateAdd(index[dim], addend);
+ index[dim] = Add(index[dim], addend);
return index;
};
const IrArray::Index input_index =
@@ -3208,10 +3193,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* shmem_buffer = param_shmem_buffers[id];
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
// global variables, so LLVM can't infer much about it.
- b_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &b_,
- "input_element"),
- b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
+ Store(input_in_logical_shape.EmitReadArrayElement(index, &b_,
+ "input_element"),
+ GEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
}
});
@@ -3232,9 +3216,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
// TODO(jlebar): Add AA metadata to this load.
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
- b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
- "output_element");
+ llvm::Instruction* load_from_shmem_buffer =
+ Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
+ "output_element");
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
index, load_from_shmem_buffer, &b_);
});
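
Taken together, the two hunks above are the usual shared-memory transpose for
the 0-2-1 case: phase one stores input elements into the tile buffer at
[0][y][x], and after a barrier phase two reads them back at [0][x][y], so
both the global-memory loads and stores stay coalesced. A CPU-side model of
one tile, with a hypothetical square input of width n and an assumed 32x32
tile:

    // Model of the tiled 0-2-1 transpose: write (y, x), read back (x, y).
    constexpr int kTile = 32;
    void TransposeTile(const float* in, float* out, int n, int ty, int tx) {
      float tile[kTile][kTile];
      for (int y = 0; y < kTile; ++y)    // phase 1: row-major reads
        for (int x = 0; x < kTile; ++x)
          tile[y][x] = in[(ty * kTile + y) * n + (tx * kTile + x)];
      // (a __syncthreads()-style barrier separates the phases on a GPU)
      for (int y = 0; y < kTile; ++y)    // phase 2: row-major writes
        for (int x = 0; x < kTile; ++x)
          out[(tx * kTile + y) * n + (ty * kTile + x)] = tile[x][y];
    }
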
@@ -3262,7 +3246,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_in_reduced_shape_arrays.size());
for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, b_.CreateExtractValue(output_value, i), &b_);
+ index, ExtractValue(output_value, i), &b_);
}
} else {
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
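
When the fused computation returns a tuple, `output_value` is an LLVM struct
and each element is peeled off with ExtractValue before being written to its
own output array. A small sketch of that idiom, with a hypothetical builder
`b` and struct-typed `output_value`:

    // Unpack a struct of results and hand each element to its own output.
    for (unsigned i = 0; i < num_outputs; ++i) {
      llvm::Value* elem = b.CreateExtractValue(output_value, {i}, "out");
      // ... EmitWriteArrayElement(index, elem, &b) on output i ...
    }
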
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 786448ea76..be12d7c90c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -249,3 +249,12 @@ cc_library(
"@llvm//:core",
],
)
+
+cc_library(
+ name = "ir_builder_mixin",
+ srcs = [],
+ hdrs = ["ir_builder_mixin.h"],
+ deps = [
+ "@llvm//:core",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
new file mode 100644
index 0000000000..abc06fb7b4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -0,0 +1,400 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+
+#include "llvm/IR/IRBuilder.h"
+
+namespace xla {
+
+// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
+// into a class.  Intended to be used as a CRTP base class, like:
+//
+//  class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
+//   public:
+//    // Hook called by the mixin; it must be accessible from the base class
+//    // (public here, or befriend IrBuilderMixin<MyIrEmitter>).
+//    llvm::IRBuilder<>* builder() { return builder_; }
+//
+//    void EmitFoo(HloInstruction* foo) {
+//      Add(Mul(...), FPToUI(...));
+//    }
+//  };
+
+template <typename Derived>
+class IrBuilderMixin {
+ protected:
+ template <class... Args>
+ llvm::Value* Add(Args&&... args) {
+ return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* AlignedLoad(Args&&... args) {
+ return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* AlignedStore(Args&&... args) {
+ return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::AllocaInst* Alloca(Args&&... args) {
+ return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* And(Args&&... args) {
+ return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicCmpXchg(Args&&... args) {
+ return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicRMW(Args&&... args) {
+ return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* BitCast(Args&&... args) {
+ return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Br(Args&&... args) {
+ return mixin_builder()->CreateBr(std::forward<Args>(args)...);
+ }
+
+ llvm::CallInst* Call(llvm::Value* callee,
+ llvm::ArrayRef<llvm::Value*> args = llvm::None,
+ const llvm::Twine& name = "",
+ llvm::MDNode* fp_math_tag = nullptr) {
+ return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
+ }
+
+ template <class... Args>
+ llvm::BranchInst* CondBr(Args&&... args) {
+ return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
+ return mixin_builder()->CreateConstInBoundsGEP1_32(
+ std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FAdd(Args&&... args) {
+ return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FMul(Args&&... args) {
+ return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateGEP(ptr, idx_list, name);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpEQ(Args&&... args) {
+ return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpNE(Args&&... args) {
+ return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULE(Args&&... args) {
+ return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULT(Args&&... args) {
+ return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* InBoundsGEP(llvm::Value* ptr,
+ llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name);
+ }
+
+ llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateExtractValue(agg, idxs, name);
+ }
+
+ llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
+ llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
+ }
+
+ template <class... Args>
+ llvm::Value* IntToPtr(Args&&... args) {
+ return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* Load(Args&&... args) {
+ return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::CallInst* MemCpy(Args&&... args) {
+ return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Mul(Args&&... args) {
+ return mixin_builder()->CreateMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWAdd(Args&&... args) {
+ return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWMul(Args&&... args) {
+ return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWSub(Args&&... args) {
+ return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Or(Args&&... args) {
+ return mixin_builder()->CreateOr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PointerCast(Args&&... args) {
+ return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PtrToInt(Args&&... args) {
+ return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SDiv(Args&&... args) {
+ return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Select(Args&&... args) {
+ return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SRem(Args&&... args) {
+ return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* Store(Args&&... args) {
+ return mixin_builder()->CreateStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UDiv(Args&&... args) {
+ return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* URem(Args&&... args) {
+ return mixin_builder()->CreateURem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* VectorSplat(Args&&... args) {
+ return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ZExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AShr(Args&&... args) {
+ return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOEQ(Args&&... args) {
+ return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOLT(Args&&... args) {
+ return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpONE(Args&&... args) {
+ return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpUNE(Args&&... args) {
+ return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FDiv(Args&&... args) {
+ return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FNeg(Args&&... args) {
+ return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPCast(Args&&... args) {
+ return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToSI(Args&&... args) {
+ return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToUI(Args&&... args) {
+ return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPTrunc(Args&&... args) {
+ return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FRem(Args&&... args) {
+ return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FSub(Args&&... args) {
+ return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSGE(Args&&... args) {
+ return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSLT(Args&&... args) {
+ return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* IntCast(Args&&... args) {
+ return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* LShr(Args&&... args) {
+ return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* MemSet(Args&&... args) {
+ return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Neg(Args&&... args) {
+ return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Not(Args&&... args) {
+ return mixin_builder()->CreateNot(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::PHINode* PHI(Args&&... args) {
+ return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* RetVoid(Args&&... args) {
+ return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Shl(Args&&... args) {
+ return mixin_builder()->CreateShl(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SIToFP(Args&&... args) {
+ return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Sub(Args&&... args) {
+ return mixin_builder()->CreateSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Trunc(Args&&... args) {
+ return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UIToFP(Args&&... args) {
+ return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Unreachable(Args&&... args) {
+ return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Xor(Args&&... args) {
+ return mixin_builder()->CreateXor(std::forward<Args>(args)...);
+ }
+
+ private:
+ llvm::IRBuilder<>* mixin_builder() {
+ return static_cast<Derived*>(this)->builder();
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
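
As a sanity check of the CRTP contract, here is a minimal, self-contained
sketch (hypothetical class name; only the mixin and LLVM core are assumed)
showing how a derived emitter exposes builder() and then uses the injected
helpers:

    #include "llvm/IR/IRBuilder.h"
    #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"

    namespace xla {

    class ToyEmitter : public IrBuilderMixin<ToyEmitter> {
     public:
      explicit ToyEmitter(llvm::IRBuilder<>* b) : b_(b) {}

      // Hook required by IrBuilderMixin; must be visible to the base class.
      llvm::IRBuilder<>* builder() { return b_; }

      // Emits (x + y) * x, using the injected Add/Mul in place of
      // b_->CreateAdd / b_->CreateMul.
      llvm::Value* EmitFma(llvm::Value* x, llvm::Value* y) {
        return Mul(Add(x, y), x, "fma");
      }

     private:
      llvm::IRBuilder<>* b_;
    };

    }  // namespace xla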