about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/cpu
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com> 2018-08-27 18:50:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-27 18:55:02 -0700
commit: fa607e7e9224b4d88ead0a81fc65c7884d25950a (patch)
tree: 3d0acc58934efb515a5b28d38b9397b7aa467205 /tensorflow/compiler/xla/service/cpu
parent: 9422d3d57c62399e425f63a769dcfa6ebd163bdc (diff)
Use a mixin to reduce llvm::IRBuilder<> related boilerplate.
PiperOrigin-RevId: 210472260
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc | 25
-rw-r--r-- tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 376
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.h | 7
5 files changed, 204 insertions, 211 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index f0adfc5d45..4cd192873f 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -278,6 +278,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index db54454707..c8312d80bd 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -30,15 +30,16 @@ limitations under the License.
namespace xla {
namespace cpu {
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- lhs = b_->CreateFPCast(lhs, b_->getFloatTy());
- rhs = b_->CreateFPCast(rhs, b_->getFloatTy());
+ lhs = FPCast(lhs, b_->getFloatTy());
+ rhs = FPCast(rhs, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@@ -58,21 +59,21 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, {lhs, rhs});
+ llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- value = b_->CreateFPCast(value, b_->getFloatTy());
+ value = FPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@@ -91,16 +92,16 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, value);
+ llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
if (hlo->opcode() == HloOpcode::kMap) {
return [this, hlo, &operand_to_generator](
const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 76833e765d..e3fba9306b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 321c2e9896..903e73f606 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -170,9 +170,9 @@ IrEmitter::~IrEmitter() {}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] =
- b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
- IrShapeType(bitcast->shape())->getPointerTo(),
- AsStringRef(IrName(bitcast)));
+ BitCast(GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(),
+ AsStringRef(IrName(bitcast)));
return Status::OK();
}
@@ -439,22 +439,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer = b_.CreateCall(
- acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer =
+ Call(acquire_func,
+ {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
- /*SrcAlign=*/1, length_32);
+ MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
- /*SrcAlign=*/1, length_32);
+ MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
- b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer,
- shape_ptr, b_.getInt32(shape_length)});
+ Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
+ b_.getInt32(shape_length)});
return Status::OK();
}
@@ -518,8 +518,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
- b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
- accumulator_address);
+ Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
std::vector<int64> window_size;
@@ -536,22 +536,21 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm::Value* in_bounds_condition = nullptr;
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
- b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
// input_index[i] < bound as an *unsigned* comparison, since a negative
// value will wrap to a large positive value.
- llvm::Value* index_condition = b_.CreateICmpULT(
- input_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ llvm::Value* index_condition =
+ ICmpULT(input_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
if (in_bounds_condition == nullptr) {
in_bounds_condition = index_condition;
} else {
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
}
CHECK(in_bounds_condition != nullptr);
@@ -564,12 +563,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
- b_.CreateStore(result, accumulator_address);
+ *reduce_window->to_apply(), {Load(accumulator_address), input_value},
+ "reducer_function");
+ Store(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_address);
+ return Load(accumulator_address);
}
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
@@ -646,7 +645,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- return b_.CreateLoad(init_value_addr);
+ return Load(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
@@ -666,7 +665,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
@@ -684,15 +683,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size());
llvm::Value* in_bounds_condition = b_.getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
- source_index[i], b_.getInt64(window.dimensions(i).stride()));
- operand_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
- operand_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ llvm::Value* strided_index =
+ NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition =
+ ICmpULT(operand_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -702,7 +700,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -711,38 +709,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
[&](const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* operand_element = Load(operand_address);
llvm::Value* result = EmitThreadLocalCall(
*select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
- "select_function");
+ {Load(selected_value_address), operand_element}, "select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -753,8 +750,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
llvm::Value* source_value =
@@ -836,7 +833,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
lhs_llvm_type, "convolution_sum_address", &b_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
- b_.CreateStore(constant_zero, sum_address);
+ Store(constant_zero, sum_address);
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
@@ -863,11 +860,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::Value* kernel_index,
const WindowDimension& window_dim) {
llvm::Value* strided_index =
- b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = b_.CreateNSWMul(
- kernel_index, b_.getInt64(window_dim.window_dilation()));
- return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index),
- b_.getInt64(window_dim.padding_low()));
+ NSWMul(output_index, b_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index =
+ NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
+ return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+ b_.getInt64(window_dim.padding_low()));
};
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -884,9 +881,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
// Also need to check that the input coordinates are not in one of the
// holes created by base dilation.
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
- llvm::Value* remainder =
- b_.CreateSRem(input_index, b_.getInt64(base_dilation));
- return b_.CreateICmpEQ(remainder, b_.getInt64(0));
+ llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
+ return ICmpEQ(remainder, b_.getInt64(0));
};
llvm::Value* in_bounds_condition = b_.getInt1(true);
@@ -894,17 +890,17 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
llvm::Value* dim_not_in_hole =
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok);
+ llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = And(in_bounds_condition, dim_ok);
}
// Now we need to map the dilated base coordinates back to the actual
// data indices on the lhs.
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
- return b_.CreateSDiv(input_index, b_.getInt64(base_dilation));
+ return SDiv(input_index, b_.getInt64(base_dilation));
};
for (int i = 0; i < num_spatial_dims; ++i) {
input_spatial[i] =
@@ -929,8 +925,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
for (int i = 0; i < num_spatial_dims; ++i) {
kernel_index[dnums.kernel_spatial_dimensions(i)] =
window.dimensions(i).window_reversal()
- ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
+ ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
: kernel_spatial[i];
}
@@ -939,13 +935,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
llvm::Value* product =
- b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_),
- kernel_array.EmitReadArrayElement(kernel_index, &b_));
- llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product);
- b_.CreateStore(sum, sum_address);
+ FMul(input_array.EmitReadArrayElement(input_index, &b_),
+ kernel_array.EmitReadArrayElement(kernel_index, &b_));
+ llvm::Value* sum = FAdd(Load(sum_address), product);
+ Store(sum, sum_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(sum_address);
+ return Load(sum_address);
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
@@ -1071,34 +1067,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
- b_.CreateCall(
- conv_func,
- {
- GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type),
- b_.CreateBitCast(lhs_address, ir_ptr_type),
- b_.CreateBitCast(rhs_address, ir_ptr_type),
- b_.getInt64(input_batch),
- b_.getInt64(input_rows),
- b_.getInt64(input_cols),
- b_.getInt64(input_channels),
- b_.getInt64(kernel_rows),
- b_.getInt64(kernel_cols),
- b_.getInt64(kernel_channels),
- b_.getInt64(kernel_filters),
- b_.getInt64(output_rows),
- b_.getInt64(output_cols),
- b_.getInt64(row_stride),
- b_.getInt64(col_stride),
- b_.getInt64(padding_top),
- b_.getInt64(padding_bottom),
- b_.getInt64(padding_left),
- b_.getInt64(padding_right),
- b_.getInt64(lhs_row_dilation),
- b_.getInt64(lhs_col_dilation),
- b_.getInt64(rhs_row_dilation),
- b_.getInt64(rhs_col_dilation),
- });
+ Call(conv_func, {
+ GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
+ BitCast(lhs_address, ir_ptr_type),
+ BitCast(rhs_address, ir_ptr_type),
+ b_.getInt64(input_batch),
+ b_.getInt64(input_rows),
+ b_.getInt64(input_cols),
+ b_.getInt64(input_channels),
+ b_.getInt64(kernel_rows),
+ b_.getInt64(kernel_cols),
+ b_.getInt64(kernel_channels),
+ b_.getInt64(kernel_filters),
+ b_.getInt64(output_rows),
+ b_.getInt64(output_cols),
+ b_.getInt64(row_stride),
+ b_.getInt64(col_stride),
+ b_.getInt64(padding_top),
+ b_.getInt64(padding_bottom),
+ b_.getInt64(padding_left),
+ b_.getInt64(padding_right),
+ b_.getInt64(lhs_row_dilation),
+ b_.getInt64(lhs_col_dilation),
+ b_.getInt64(rhs_row_dilation),
+ b_.getInt64(rhs_col_dilation),
+ });
return Status::OK();
}
@@ -1158,15 +1152,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
const int fft_rank = fft_length.size();
- b_.CreateCall(
- fft_func,
- {GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
- b_.CreateBitCast(operand_address, int8_ptr_type),
- b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank),
- b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
- b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
- b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+ Call(fft_func,
+ {GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
+ b_.getInt32(fft_rank), b_.getInt64(input_batch),
+ b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
return Status::OK();
}
@@ -1205,8 +1198,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
- /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
+ MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
return Status::OK();
@@ -1465,19 +1458,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
accumulator_shard_type, "accumulator", &b_, 0));
}
- llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value));
+ llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
for (llvm::Value* accumulator_shard : accumulator) {
llvm::Value* initial_value;
auto shard_type = accumulator_shard->getType()->getPointerElementType();
if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
initial_value =
- b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa);
+ VectorSplat(vector_type->getNumElements(), init_value_ssa);
} else {
initial_value = init_value_ssa;
}
- b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment);
+ AlignedStore(initial_value, accumulator_shard, element_alignment);
}
llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
@@ -1499,24 +1492,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
}
CHECK(output_index.end() == it);
- llvm::Value* input_address = b_.CreateBitCast(
+ llvm::Value* input_address = BitCast(
arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
for (int i = 0; i < accumulator.size(); i++) {
auto input_address_typed =
- b_.CreateBitCast(input_address, accumulator[i]->getType());
+ BitCast(input_address, accumulator[i]->getType());
auto current_accumulator_value =
- b_.CreateAlignedLoad(accumulator[i], element_alignment);
- auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment);
+ AlignedLoad(accumulator[i], element_alignment);
+ auto addend = AlignedLoad(input_address_typed, element_alignment);
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
auto reduced_result =
reduction_generator(&b_, current_accumulator_value, addend);
- b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment);
+ AlignedStore(reduced_result, accumulator[i], element_alignment);
if (i != (accumulator.size() - 1)) {
- input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(),
- input_address_typed, 1);
+ input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
+ input_address_typed, 1);
}
}
@@ -1525,8 +1518,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
ShardedVector result_ssa;
result_ssa.reserve(accumulator.size());
for (auto accumulator_shard : accumulator) {
- result_ssa.push_back(
- b_.CreateAlignedLoad(accumulator_shard, element_alignment));
+ result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
}
return result_ssa;
}
@@ -1535,18 +1527,18 @@ void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
const int alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
- auto store_address_typed = b_.CreateBitCast(
- store_address,
- llvm::PointerType::getUnqual(value_to_store[i]->getType()));
+ auto store_address_typed =
+ BitCast(store_address,
+ llvm::PointerType::getUnqual(value_to_store[i]->getType()));
- auto store_instruction = b_.CreateAlignedStore(
- value_to_store[i], store_address_typed, alignment);
+ auto store_instruction =
+ AlignedStore(value_to_store[i], store_address_typed, alignment);
containing_array.AnnotateLoadStoreInstructionWithMetadata(
store_instruction);
if (i != (value_to_store.size() - 1)) {
- store_address = b_.CreateConstInBoundsGEP1_32(
- value_to_store[i]->getType(), store_address_typed, 1);
+ store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
+ store_address_typed, 1);
}
}
}
@@ -1711,8 +1703,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = b_.CreateLoad(init_value_addr);
- b_.CreateStore(load_init_value, accumulator_addr);
+ llvm::Value* load_init_value = Load(init_value_addr);
+ Store(load_init_value, accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to compute
// the actual target element. For this, we build a new loop nest to iterate
@@ -1745,12 +1737,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
// Apply the reduction function to the loaded value.
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ *reduce->to_apply(), {Load(accumulator_addr), input_element},
"reduce_function");
- b_.CreateStore(result, accumulator_addr);
+ Store(result, accumulator_addr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
@@ -1988,7 +1980,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
[this, pad](const llvm_ir::IrArray::Index& target_index) {
const HloInstruction* padding_value = pad->operand(1);
llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
- return b_.CreateLoad(padding_value_addr);
+ return Load(padding_value_addr);
}));
// Create a loop to iterate over the operand elements and update the output
@@ -2010,10 +2002,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
- llvm::Value* offset = b_.CreateMul(
- operand_index[i],
- b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
- llvm::Value* index = b_.CreateAdd(
+ llvm::Value* offset =
+ Mul(operand_index[i],
+ b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
+ llvm::Value* index = Add(
offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
output_index.push_back(index);
}
@@ -2124,10 +2116,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
- b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca =
- b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)});
- b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ InBoundsGEP(operands_alloca, {b_.getInt64(i)});
+ Store(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
@@ -2139,9 +2131,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
auto* output_address_arg =
- b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
- b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca});
+ Call(custom_call_ir_function, {output_address_arg, operands_alloca});
return Status::OK();
}
@@ -2200,15 +2192,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
compute_function_->function());
- b_.CreateBr(header_bb);
+ Br(header_bb);
b_.SetInsertPoint(header_bb);
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
- llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ llvm::Value* while_predicate = ICmpNE(
+ Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2217,7 +2208,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
- b_.CreateCondBr(while_predicate, body_bb, exit_bb);
+ CondBr(while_predicate, body_bb, exit_bb);
// Calls the body function from the body block.
b_.SetInsertPoint(body_bb);
@@ -2226,7 +2217,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
// Finishes with a branch back to the header.
- b_.CreateBr(header_bb);
+ Br(header_bb);
// Adds the exit block to the function and sets the insert point there.
compute_function_->function()->getBasicBlockList().push_back(exit_bb);
@@ -2273,7 +2264,6 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
output_min2maj.end());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
- llvm::Type* i8_type = b_.getInt8Ty();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
@@ -2296,9 +2286,9 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
llvm::Value* target_region_begin =
- b_.CreateBitCast(target_array.EmitArrayElementAddress(
- outer_dims_index, &b_, "target_region"),
- i8_ptr_type);
+ BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_,
+ "target_region"),
+ i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
@@ -2312,13 +2302,12 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
- llvm::Value* copy_source_address = b_.CreateBitCast(
+ llvm::Value* copy_source_address = BitCast(
source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"),
i8_ptr_type);
llvm::Value* copy_target_address =
- b_.CreateGEP(i8_type, target_region_begin,
- b_.getInt64(byte_offset_into_target_region));
+ GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
@@ -2350,15 +2339,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
- auto* load_instruction = b_.CreateAlignedLoad(
- b_.CreateBitCast(source, primitive_ptr_type), element_alignment);
+ auto* load_instruction =
+ AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
- auto* store_instruction = b_.CreateAlignedStore(
- load_instruction, b_.CreateBitCast(target, primitive_ptr_type),
- element_alignment);
+ auto* store_instruction =
+ AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
+ element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
- auto* memcpy_instruction = b_.CreateMemCpy(
+ auto* memcpy_instruction = MemCpy(
target, /*DstAlign=*/element_alignment, source,
/*SrcAlign=*/element_alignment, element_count * primitive_type_size);
@@ -2420,9 +2409,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
- llvm::LoadInst* pred_value = b_.CreateLoad(
- GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = b_.CreateICmpNE(
+ llvm::LoadInst* pred_value =
+ Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = ICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
@@ -2509,8 +2498,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon(
int64 prof_counter_idx = it->second;
string counter_name = IrName("prof_counter", hlo.name());
- return b_.CreateGEP(GetProfileCountersArgument(),
- b_.getInt64(prof_counter_idx), AsStringRef(counter_name));
+ return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
+ AsStringRef(counter_name));
}
void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
@@ -2664,8 +2653,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
+ llvm::LoadInst* param_address_untyped = Load(param_address_offset);
if (!ShapeUtil::IsOpaque(target_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
@@ -2693,8 +2681,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
}
return buf_it->second;
}();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
@@ -2702,7 +2689,7 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
- llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
+ llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
tempbuf_address_base->setMetadata(
@@ -2716,10 +2703,10 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
if (slice.offset() > 0) {
// Adjust the address to account for the slice offset.
tempbuf_address_untyped =
- b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
+ InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
- return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address_untyped,
+ IrShapeType(target_shape)->getPointerTo());
}
llvm::Value* IrEmitter::EmitTempBufferPointer(
@@ -2805,8 +2792,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
- /*SrcAlign=*/1, source_size);
+ MemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
@@ -2860,7 +2847,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
CHECK(!parameter->getType()->isPointerTy());
llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
+ Store(parameter, parameter_addr);
parameter_addrs.push_back(parameter_addr);
}
@@ -2869,29 +2856,28 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
- return b_.CreateLoad(return_value_buffer);
+ return Load(return_value_buffer);
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
absl::string_view name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 99c080b3db..ec68710d3f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -55,7 +56,8 @@ namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
// Create a new LLVM IR emitter.
//
@@ -100,6 +102,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return &b_; }
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();