Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc  1642
1 file changed, 801 insertions(+), 841 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6b66a4b0b7..09909b62ba 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -51,10 +51,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -89,14 +90,14 @@ IrEmitter::IrEmitter(
: assignment_(assignment),
module_(llvm_module),
arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
- ir_builder_(llvm_module->getContext()),
+ b_(llvm_module->getContext()),
instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
computation_to_profile_idx_(std::move(computation_to_profile_idx)),
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
hlo_module_config_(hlo_module.config()),
is_top_level_computation_(false),
target_machine_features_(*target_machine_features) {
- ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
+ b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_enable_fast_math()));
}
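The renamed builder b_ inherits XLA's fast-math setting in the constructor above. A minimal sketch of how fast-math flags attach to an llvm::IRBuilder<> of this vintage, assuming llvm_ir::GetFastMathFlags reduces to roughly this:

// Sketch only: how fast-math flags end up on the builder. Assumes the LLVM
// API of this era (FastMathFlags::setFast).
#include "llvm/IR/IRBuilder.h"

void ConfigureFastMath(llvm::IRBuilder<>* b, bool fast_math_enabled) {
  llvm::FastMathFlags flags;
  if (fast_math_enabled) {
    flags.setFast();  // enables nnan, ninf, nsz, arcp, contract, afn, reassoc
  }
  // Every floating-point instruction the builder creates afterwards
  // carries these flags.
  b->setFastMathFlags(flags);
}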
@@ -115,6 +116,19 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
+ if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
+ TF_ASSIGN_OR_RETURN(
+ computation_root_allocation_,
+ assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
+ }
+
+ for (const HloInstruction* param : computation->parameter_instructions()) {
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
+ assignment_.GetUniqueTopLevelSlice(param));
+ computation_parameter_allocations_[param_slice.allocation()->index()] =
+ param->parameter_number();
+ }
+
InitializeIrFunction(function_name);
  // The rdtscp instruction is x86-specific. We will fall back to LLVM's
  // generic readcyclecounter if it is unavailable.
@@ -131,6 +145,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
+ computation_root_allocation_ = BufferAllocation::Slice();
+ computation_parameter_allocations_.clear();
return ir_function;
}
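The two members cleared here are new bookkeeping: EmitComputation records the root instruction's top-level buffer slice and a map from each parameter's allocation index to its parameter number, then resets both once the function is emitted. A sketch of the assumed declarations, with types inferred from the calls above (the real ones live in ir_emitter.h):

// Sketch; declarations inferred from usage, not shown in this diff.
BufferAllocation::Slice computation_root_allocation_;
// Allocation index -> parameter number, presumably consulted when emitting
// buffer pointers so parameter-backed slices can be recognized.
std::unordered_map<BufferAllocation::Index, int64>
    computation_parameter_allocations_;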
@@ -146,7 +162,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
new IrFunction(function_name, linkage,
options::OptimizeForSizeRequested(hlo_module_config_),
hlo_module_config_.debug_options().xla_enable_fast_math(),
- module_, &ir_builder_, num_dynamic_loop_bounds_));
+ module_, &b_, num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
@@ -154,9 +170,9 @@ IrEmitter::~IrEmitter() {}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] =
- ir_builder_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
- IrShapeType(bitcast->shape())->getPointerTo(),
- AsStringRef(IrName(bitcast)));
+ b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(),
+ AsStringRef(IrName(bitcast)));
return Status::OK();
}
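HandleBitcast emits no copy: the operand's existing buffer is reinterpreted at the pointer type of the new shape. The pattern in isolation, as a minimal sketch:

// Sketch: a bitcast reinterprets the buffer; it moves no data.
#include "llvm/IR/IRBuilder.h"

llvm::Value* EmitBitcastSketch(llvm::IRBuilder<>* b,
                               llvm::Value* operand_buffer,
                               llvm::Type* new_shape_ir_type) {
  // Same bytes, new static type; no load or store is generated.
  return b->CreateBitCast(operand_buffer, new_shape_ir_type->getPointerTo());
}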
@@ -175,25 +191,36 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
result_global, IrShapeType(literal.shape())->getPointerTo());
}
-Status IrEmitter::HandleConstant(HloInstruction* constant) {
- VLOG(2) << "HandleConstant: " << constant->ToString();
- const Literal& literal = constant->literal();
- llvm::Constant* global_for_const;
+Status IrEmitter::EmitConstantGlobals() {
+ for (const BufferAllocation& allocation : assignment_.Allocations()) {
+ if (!allocation.is_constant()) {
+ continue;
+ }
- auto it = emitted_literals_.find(&literal);
- if (it != emitted_literals_.end()) {
- global_for_const = it->second;
- } else {
- global_for_const = EmitGlobalForLiteral(literal);
- emitted_literals_[&literal] = global_for_const;
+ const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
+ llvm::Constant* global_for_const;
+ auto it = emitted_literals_.find(&literal);
+ if (it != emitted_literals_.end()) {
+ global_for_const = it->second;
+ } else {
+ global_for_const = EmitGlobalForLiteral(literal);
+ InsertOrDie(&emitted_literals_, &literal, global_for_const);
+ }
+
+ InsertOrDie(&constant_buffer_to_global_, allocation.index(),
+ global_for_const);
}
- emitted_value_[constant] = global_for_const;
- VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const);
- VLOG(2) << " its type: "
- << llvm_ir::DumpToString(*global_for_const->getType());
+
return Status::OK();
}
+Status IrEmitter::HandleConstant(HloInstruction* constant) {
+ VLOG(2) << "HandleConstant: " << constant->ToString();
+ // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
+ // of the constant.
+ return EmitTargetAddressForOp(constant);
+}
+
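The constant-emission logic above deduplicates literals, so identical constants share one LLVM global and HandleConstant only has to bind the instruction to its address. The find-or-emit pattern in isolation, written as a hypothetical generic helper (not code from this change):

// Sketch of the memoization pattern used by EmitConstantGlobals.
#include <unordered_map>

template <typename Key, typename Value, typename EmitFn>
Value* FindOrEmit(std::unordered_map<Key, Value*>* cache, Key key,
                  EmitFn emit) {
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;  // already emitted: reuse it
  Value* value = emit();                      // first sighting: emit once
  (*cache)[key] = value;
  return value;
}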
Status IrEmitter::HandleCopy(HloInstruction* copy) {
if (ShapeUtil::IsTuple(copy->shape())) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
@@ -273,27 +300,30 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Shape& shape = get_tuple_element->shape();
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
- GetEmittedValueFor(operand), &ir_builder_, module_);
+ GetEmittedValueFor(operand), &b_, module_);
return Status::OK();
}
Status IrEmitter::HandleSelect(HloInstruction* select) {
auto pred = select->operand(0);
- auto on_true = select->operand(1);
- auto on_false = select->operand(2);
TF_RET_CHECK(pred->shape().element_type() == PRED);
-
- if (ShapeUtil::IsTuple(select->shape())) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
- llvm_ir::EmitTupleSelect(
- GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
- GetEmittedValueFor(on_false), &ir_builder_, module_);
- return Status::OK();
- }
-
return DefaultAction(select);
}
+Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
+ auto pred = tuple_select->operand(0);
+ auto on_true = tuple_select->operand(1);
+ auto on_false = tuple_select->operand(2);
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+ TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
+ TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape()));
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
+ llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
+ GetEmittedValueFor(on_true),
+ GetEmittedValueFor(on_false), &b_, module_);
+ return Status::OK();
+}
+
Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
VLOG(2) << "HandleInfeed: " << infeed->ToString();
@@ -313,8 +343,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
assignment_.GetUniqueSlice(infeed, {1}));
llvm::Value* token_address = EmitTempBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
- llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address},
- &ir_builder_, module_);
+ llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
+ module_);
if (ShapeUtil::IsTuple(data_shape)) {
TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
@@ -345,7 +375,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
}
llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
- tuple_element_addresses, &ir_builder_, module_);
+ tuple_element_addresses, &b_, module_);
} else {
TF_RETURN_IF_ERROR(
EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
@@ -366,14 +396,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
int32 length_32 = static_cast<int32>(length);
int32 shape_length;
- TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
- llvm_ir::EncodeSelfDescribingShapeConstant(
- shape, &shape_length, &ir_builder_));
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * shape_ptr,
+ llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
// The signature of the acquire infeed buffer function is:
//
// (void*)(int32 length);
- llvm::Type* int32_type = ir_builder_.getInt32Ty();
+ llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* acquire_type = llvm::FunctionType::get(
i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
@@ -393,8 +423,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
//
// (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- ir_builder_.getVoidTy(),
- {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
+ b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -411,25 +440,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer = ir_builder_.CreateCall(
- acquire_func, {ir_builder_.getInt32(length_32), shape_ptr,
- ir_builder_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer = b_.CreateCall(
+ acquire_func,
+ {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1,
- acquired_pointer,
- /*SrcAlign=*/1, length_32);
+ b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1,
- program_buffer_address,
- /*SrcAlign=*/1, length_32);
+ b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
- ir_builder_.CreateCall(release_func,
- {ir_builder_.getInt32(length_32), acquired_pointer,
- shape_ptr, ir_builder_.getInt32(shape_length)});
+ b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer,
+ shape_ptr, b_.getInt32(shape_length)});
return Status::OK();
}
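The transfer above follows a three-step protocol with the CPU runtime: acquire a staging buffer of exactly length_32 bytes, memcpy in the direction implied by the xfeed kind, then release. A sketch of the same sequence as host-side C++, for the infeed direction; the symbol names match the XLA CPU runtime of this era, but treat the signatures as illustrative:

// Sketch of the emitted acquire/copy/release call sequence.
#include <cstdint>
#include <cstring>

extern "C" void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    int32_t buffer_length, const void* shape, int32_t shape_length);
extern "C" void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    int32_t buffer_length, void* buffer_ptr, const void* shape,
    int32_t shape_length);

void InfeedTransferSketch(void* program_buffer, int32_t length,
                          const void* shape, int32_t shape_length) {
  void* staging = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
      length, shape, shape_length);  // runtime CHECK-fails on size mismatch
  std::memcpy(program_buffer, staging, length);  // staging -> program buffer
  __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, staging, shape,
                                                    shape_length);
}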
@@ -450,7 +476,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
ShapeUtil::GetTupleElementShape(operand_shape, i);
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
- value, &ir_builder_, module_);
+ value, &b_, module_);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
tuple_element_shape, tuple_element));
}
@@ -469,46 +495,96 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
for (auto operand : tuple->operands()) {
base_ptrs.push_back(GetEmittedValueFor(operand));
}
- llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_);
+ llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_);
return Status::OK();
}
-Status IrEmitter::HandleMap(HloInstruction* map) {
- gtl::ArraySlice<HloInstruction*> operands(map->operands());
- HloComputation* function = map->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
+llvm::Value* IrEmitter::EmitElementalMap(
+ const HloMapInstruction& map_instr,
+ tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ tensorflow::StringPiece name) {
+ return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
+}
- return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function](
- const llvm_ir::IrArray::Index& index) {
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : operands) {
- const llvm_ir::IrArray& array = GetIrArrayFor(operand);
- parameter_addresses.push_back(
- array.EmitArrayElementAddress(index, &ir_builder_));
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* operand = reduce_window->operand(0);
+ const Window& window = reduce_window->window();
+
+ // We fold inputs into the accumulator and initialize it to
+ // the initial value on the reduce_window.
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
+ "reduce_window_accumulator_address", &b_,
+ MinimumAlignmentForPrimitiveType(operand_element_type));
+ b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
+
+ llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
+
+ llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), index.size());
+ llvm::Value* in_bounds_condition = nullptr;
+ for (size_t i = 0; i < index.size(); ++i) {
+ llvm::Value* strided_index =
+ b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
+ input_index[i] =
+ b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+    // We need to check that 0 <= input_index[i] < bound; otherwise we are in
+    // the padding and can skip the computation. That is equivalent to
+    // input_index[i] < bound as an *unsigned* comparison, since a negative
+    // value will wrap to a large positive value.
+ llvm::Value* index_condition = b_.CreateICmpULT(
+ input_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = index_condition;
+ } else {
+ in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
}
- return EmitElementFunctionCall(mapped_ir_function, map->shape(),
- parameter_addresses, "map_function");
- });
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
+ SetToFirstInsertPoint(if_data.true_block, &b_);
+
+ // We are not in the padding, so carry out the computation.
+ llvm_ir::IrArray input_array(GetIrArrayFor(operand));
+ llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce_window->to_apply(),
+ {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
+ b_.CreateStore(result, accumulator_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
+ return b_.CreateLoad(accumulator_address);
}
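The bounds check above leans on a standard trick: for bound >= 0, the two-sided test 0 <= i < bound collapses to a single unsigned comparison, because a negative i reinterpreted as unsigned exceeds any valid bound. In plain C++:

// Sketch: why one unsigned compare implements 0 <= i < bound.
#include <cstdint>

bool InBounds(int64_t i, int64_t bound) {
  // If i < 0, static_cast<uint64_t>(i) >= 2^63, which is larger than any
  // non-negative bound, so the comparison fails exactly when it should.
  return static_cast<uint64_t>(i) < static_cast<uint64_t>(bound);
}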
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
- auto operand = reduce_window->operand(0);
- const Window& window = reduce_window->window();
- HloComputation* function = reduce_window->to_apply();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
- /*instruction=*/*reduce_window, /*operands=*/{operand},
+ /*instruction=*/*reduce_window,
+ /*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32}));
// TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
+ if (window_util::HasDilation(reduce_window->window())) {
return Unimplemented(
"Dilation for ReduceWindow is not implemented on CPU.");
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
@@ -523,73 +599,9 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// This is completely un-optimized and just here to have something
// that works.
return EmitTargetElementLoop(
- reduce_window, [this, reduce_window, operand, window,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // We fold inputs into the accumulator and initialize it to
- // the initial value on the reduce_window.
- PrimitiveType operand_element_type = operand->shape().element_type();
- llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accumulator_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(operand_element_type));
- ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
- reduce_window->operand(1))),
- accumulator_address);
-
- llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"),
- &ir_builder_);
- std::vector<int64> window_size;
- for (const auto& dim : window.dimensions()) {
- window_size.push_back(dim.size());
- }
- const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
- ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- CHECK_EQ(window_index.size(), index.size());
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(),
- index.size());
- llvm::Value* in_bounds_condition = nullptr;
- for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
- input_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- ir_builder_.getInt64(window.dimensions(i).padding_low()));
-
- // We need to check if 0 <= input_index[i] < bound, as
- // otherwise we are in the padding so that we can skip the
- // computation. That is equivalent to input_index[i] < bound
- // as an *unsigned* comparison, since a negative value will
- // wrap to a large positive value.
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
- input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
- operand->shape(), i)));
- if (in_bounds_condition == nullptr) {
- in_bounds_condition = index_condition;
- } else {
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
- }
- }
- CHECK(in_bounds_condition != nullptr);
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value_address =
- input_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce_window->shape(),
- {accumulator_address, input_value_address}, "reducer_function");
- ir_builder_.CreateStore(result, accumulator_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_address);
+ reduce_window, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduceWindow(
+ Cast<HloReduceWindowInstruction>(reduce_window), index);
});
}
@@ -610,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
- // The select and scatter computations should have been emitted previously.
- llvm::Function* select_function =
- FindOrDie(emitted_functions_, select_and_scatter->select());
- llvm::Function* scatter_function =
- FindOrDie(emitted_functions_, select_and_scatter->scatter());
-
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@@ -641,141 +647,128 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- return ir_builder_.CreateLoad(init_value_addr);
+ return b_.CreateLoad(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
- llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &ir_builder_);
+ llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
const llvm_ir::IrArray::Index source_index =
source_loops.AddLoopsForShape(source->shape(), "source");
- SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
+ SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
// Allocate space to keep the currently selected value, its index, and
// the boolean initialized_flag, which is initially set to false.
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "selected_value_address", &ir_builder_,
+ "selected_value_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
- "selected_index_address", &ir_builder_);
+ b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
- ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address);
+ b_.getInt1Ty(), "initialized_flag_address", &b_);
+ b_.CreateStore(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
- llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"),
- &ir_builder_);
+ llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
}
const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
+ SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(),
- source_index.size());
- llvm::Value* in_bounds_condition = ir_builder_.getTrue();
+ llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size());
+ llvm::Value* in_bounds_condition = b_.getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
- operand_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- ir_builder_.getInt64(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ llvm::Value* strided_index = b_.CreateNSWMul(
+ source_index[i], b_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] =
+ b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = b_.CreateICmpULT(
operand_index[i],
- ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
// Only need to do something if the operand index is within the bounds. First
// check if the initialized_flag is set.
llvm_ir::LlvmIfData if_in_bounds =
- llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
- llvm_ir::LlvmIfData if_initialized =
- llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address),
- "initialized", &ir_builder_);
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
+ SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
+ llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
+ b_.CreateLoad(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
- SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
+ SetToFirstInsertPoint(if_initialized.false_block, &b_);
const auto save_operand_index =
[&](const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- ir_builder_.CreateInBoundsGEP(selected_index_address,
- {ir_builder_.getInt32(i)});
- ir_builder_.CreateStore(operand_index[i],
- selected_index_address_slot);
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ b_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
- operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
- ir_builder_.CreateStore(operand_data, selected_value_address);
+ operand_array.EmitReadArrayElement(operand_index, &b_);
+ b_.CreateStore(operand_data, selected_value_address);
save_operand_index(operand_index);
- ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address);
+ b_.CreateStore(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
- SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
- const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
+ SetToFirstInsertPoint(if_initialized.true_block, &b_);
llvm::Value* operand_address =
- operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- select_function, output_shape, {selected_value_address, operand_address},
+ operand_array.EmitArrayElementAddress(operand_index, &b_);
+ llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* result = EmitThreadLocalCall(
+ *select_and_scatter->select(),
+ {b_.CreateLoad(selected_value_address), operand_element},
"select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = ir_builder_.CreateICmpNE(
+ llvm::Value* cond = b_.CreateICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
- llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
- SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
- ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
- selected_value_address);
+ llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
+ SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
+ b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
// the selected index of the output. The value we store at the output
// location is computed by calling the `scatter` function with the source
// value and the current output value.
- SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
+ SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
- selected_index_address, {ir_builder_.getInt32(i)});
- selected_index.push_back(
- ir_builder_.CreateLoad(selected_index_address_slot));
+ llvm::Value* selected_index_address_slot =
+ b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
- llvm::Value* source_value_address =
- source_array.EmitArrayElementAddress(source_index, &ir_builder_);
+ llvm::Value* source_value =
+ source_array.EmitReadArrayElement(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
- llvm::Value* output_value_address =
- output_array.EmitArrayElementAddress(selected_index, &ir_builder_);
- llvm::Value* scatter_value = EmitElementFunctionCall(
- scatter_function, source->shape(),
- {output_value_address, source_value_address}, "scatter_function");
- output_array.EmitWriteArrayElement(selected_index, scatter_value,
- &ir_builder_);
-
- SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
+ llvm::Value* output_value =
+ output_array.EmitReadArrayElement(selected_index, &b_);
+ llvm::Value* scatter_value =
+ EmitThreadLocalCall(*select_and_scatter->scatter(),
+ {output_value, source_value}, "scatter_function");
+ output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
+
+ SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
return Status::OK();
}
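Note the calling-convention shift running through this hunk: the select and scatter computations are no longer looked up in emitted_functions_ and invoked with operand addresses; EmitThreadLocalCall receives loaded SSA values and returns one. Schematically, condensed from the lines above:

// Sketch of the call-site change (illustrative, condensed):
//
// Before: address-passing call to a pre-emitted function.
//   llvm::Value* result = EmitElementFunctionCall(
//       select_function, pred_shape,
//       {selected_value_address, operand_address}, "select_function");
//
// After: value-passing call emitted on demand.
//   llvm::Value* result = EmitThreadLocalCall(
//       *select_and_scatter->select(),
//       {b_.CreateLoad(selected_value_address), b_.CreateLoad(operand_address)},
//       "select_function");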
@@ -814,21 +807,155 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// Dot operation is complicated so we delegate to a helper class.
return DotOpEmitter::EmitDotOperation(
*dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr,
- GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
+ GetExecutableRunOptionsArgument(), &b_, hlo_module_config_,
target_machine_features_);
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* lhs = convolution->operand(0);
+ const HloInstruction* rhs = convolution->operand(1);
+ const Window& window = convolution->window();
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+ int num_spatial_dims = dnums.output_spatial_dimensions_size();
+ std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+ }
+ llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+ llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+ // We will accumulate the products into this sum to calculate the output entry
+ // at the given index.
+ PrimitiveType lhs_element_type = lhs->shape().element_type();
+ llvm::Type* lhs_llvm_type =
+ llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+ llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ lhs_llvm_type, "convolution_sum_address", &b_,
+ MinimumAlignmentForPrimitiveType(lhs_element_type));
+ llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
+ b_.CreateStore(constant_zero, sum_address);
+
+ llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
+ std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial[i] =
+ loops
+ .AddLoop(
+ 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+ tensorflow::strings::StrCat("k", i))
+ ->GetIndVarValue();
+ }
+ llvm::Value* input_feature =
+ loops
+ .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+ "iz")
+ ->GetIndVarValue();
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
+
+ // Calculate the spatial index in the input array, taking striding, dilation
+ // and padding into account. An index in the padding will be out of the bounds
+ // of the array.
+ const auto calculate_input_index = [this](llvm::Value* output_index,
+ llvm::Value* kernel_index,
+ const WindowDimension& window_dim) {
+ llvm::Value* strided_index =
+ b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index = b_.CreateNSWMul(
+ kernel_index, b_.getInt64(window_dim.window_dilation()));
+ return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index),
+ b_.getInt64(window_dim.padding_low()));
+ };
+ std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] = calculate_input_index(
+ output_spatial[i], kernel_spatial[i], window.dimensions(i));
+ }
+
+  // We need to check that 0 <= input dim < bound; otherwise we are in the
+  // padding and can skip the computation. That is equivalent to input
+  // dim < bound as an *unsigned* comparison, since a negative value will wrap
+  // to a large positive value. The input dim is dilated, so we need to dilate
+  // the bound as well to match.
+
+ // Also need to check that the input coordinates are not in one of the
+ // holes created by base dilation.
+ const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+ llvm::Value* remainder =
+ b_.CreateSRem(input_index, b_.getInt64(base_dilation));
+ return b_.CreateICmpEQ(remainder, b_.getInt64(0));
+ };
+
+ llvm::Value* in_bounds_condition = b_.getInt1(true);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
+ lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+ window.dimensions(i).base_dilation()));
+ llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_not_in_hole =
+ not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+ llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok);
+ }
+
+ // Now we need to map the dilated base coordinates back to the actual
+ // data indices on the lhs.
+ const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+ return b_.CreateSDiv(input_index, b_.getInt64(base_dilation));
+ };
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] =
+ undilate(input_spatial[i], window.dimensions(i).base_dilation());
+ }
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
+ SetToFirstInsertPoint(if_data.true_block, &b_);
+
+ // We are not in the padding, so carry out the computation.
+ int num_dims = num_spatial_dims + 2;
+ llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+ }
+ input_index[dnums.input_feature_dimension()] = input_feature;
+ input_index[dnums.input_batch_dimension()] = batch;
+
+ llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
+ llvm_ir::IrArray::Index kernel_index(b_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_index[dnums.kernel_spatial_dimensions(i)] =
+ window.dimensions(i).window_reversal()
+ ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
+ : kernel_spatial[i];
+ }
+
+ kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
+ kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+ llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
+ llvm::Value* product =
+ b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_),
+ kernel_array.EmitReadArrayElement(kernel_index, &b_));
+ llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product);
+ b_.CreateStore(sum, sum_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
+ return b_.CreateLoad(sum_address);
+}
+
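For each spatial dimension, the loop body above computes input = output * stride + kernel * window_dilation - padding_low, rejects taps that land in padding or in a base-dilation hole, and undilates the survivor back to a real lhs coordinate. The same arithmetic for one dimension, as a scalar C++ sketch (parameter names mirror WindowDimension; the helper is hypothetical):

// Sketch: returns the lhs index for (output, kernel), or -1 if the tap
// lands in padding or in a base-dilation hole.
#include <cstdint>

int64_t ConvInputIndexSketch(int64_t output, int64_t kernel, int64_t stride,
                             int64_t window_dilation, int64_t padding_low,
                             int64_t base_dilation, int64_t input_bound) {
  int64_t dilated = output * stride + kernel * window_dilation - padding_low;
  // window_util::DilatedBound(bound, d) == (bound - 1) * d + 1 for bound > 0.
  int64_t dilated_bound = (input_bound - 1) * base_dilation + 1;
  if (dilated < 0 || dilated >= dilated_bound) return -1;  // padding
  if (dilated % base_dilation != 0) return -1;             // dilation hole
  return dilated / base_dilation;                          // undilate
}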
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
- const auto& window = convolution->window();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F16, F32, C64}));
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
-
  // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
  // different data layouts.
if (PotentiallyImplementedAsEigenConvolution(*convolution,
@@ -908,12 +1035,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
PrimitiveType primitive_type = lhs->shape().element_type();
llvm::Type* ir_ptr_type = primitive_type == F16
- ? ir_builder_.getHalfTy()->getPointerTo()
- : ir_builder_.getFloatTy()->getPointerTo();
- llvm::Type* int64_type = ir_builder_.getInt64Ty();
- llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
+ ? b_.getHalfTy()->getPointerTo()
+ : b_.getFloatTy()->getPointerTo();
+ llvm::Type* int64_type = b_.getInt64Ty();
+ llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
llvm::FunctionType* conv_type = llvm::FunctionType::get(
- ir_builder_.getVoidTy(),
+ b_.getVoidTy(),
{int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
int64_type, int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type, int64_type,
@@ -945,34 +1072,34 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
- ir_builder_.CreateCall(
- conv_func, {
- GetExecutableRunOptionsArgument(),
- ir_builder_.CreateBitCast(
- GetEmittedValueFor(convolution), ir_ptr_type),
- ir_builder_.CreateBitCast(lhs_address, ir_ptr_type),
- ir_builder_.CreateBitCast(rhs_address, ir_ptr_type),
- ir_builder_.getInt64(input_batch),
- ir_builder_.getInt64(input_rows),
- ir_builder_.getInt64(input_cols),
- ir_builder_.getInt64(input_channels),
- ir_builder_.getInt64(kernel_rows),
- ir_builder_.getInt64(kernel_cols),
- ir_builder_.getInt64(kernel_channels),
- ir_builder_.getInt64(kernel_filters),
- ir_builder_.getInt64(output_rows),
- ir_builder_.getInt64(output_cols),
- ir_builder_.getInt64(row_stride),
- ir_builder_.getInt64(col_stride),
- ir_builder_.getInt64(padding_top),
- ir_builder_.getInt64(padding_bottom),
- ir_builder_.getInt64(padding_left),
- ir_builder_.getInt64(padding_right),
- ir_builder_.getInt64(lhs_row_dilation),
- ir_builder_.getInt64(lhs_col_dilation),
- ir_builder_.getInt64(rhs_row_dilation),
- ir_builder_.getInt64(rhs_col_dilation),
- });
+ b_.CreateCall(
+ conv_func,
+ {
+ GetExecutableRunOptionsArgument(),
+ b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type),
+ b_.CreateBitCast(lhs_address, ir_ptr_type),
+ b_.CreateBitCast(rhs_address, ir_ptr_type),
+ b_.getInt64(input_batch),
+ b_.getInt64(input_rows),
+ b_.getInt64(input_cols),
+ b_.getInt64(input_channels),
+ b_.getInt64(kernel_rows),
+ b_.getInt64(kernel_cols),
+ b_.getInt64(kernel_channels),
+ b_.getInt64(kernel_filters),
+ b_.getInt64(output_rows),
+ b_.getInt64(output_cols),
+ b_.getInt64(row_stride),
+ b_.getInt64(col_stride),
+ b_.getInt64(padding_top),
+ b_.getInt64(padding_bottom),
+ b_.getInt64(padding_left),
+ b_.getInt64(padding_right),
+ b_.getInt64(lhs_row_dilation),
+ b_.getInt64(lhs_col_dilation),
+ b_.getInt64(rhs_row_dilation),
+ b_.getInt64(rhs_col_dilation),
+ });
return Status::OK();
}
@@ -985,150 +1112,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// See the description of convolution in the XLA documentation for the pseudo
// code for convolution.
return EmitTargetElementLoop(
- convolution, [this, convolution, lhs, rhs, window,
- dnums](const llvm_ir::IrArray::Index& index) {
- int num_spatial_dims = dnums.output_spatial_dimensions_size();
- std::vector<llvm::Value*> output_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
- }
- llvm::Value* output_feature = index[dnums.output_feature_dimension()];
- llvm::Value* batch = index[dnums.output_batch_dimension()];
-
- // We will accumulate the products into this sum to calculate
- // the output entry at the given index.
- PrimitiveType lhs_element_type = lhs->shape().element_type();
- llvm::Type* lhs_llvm_type =
- llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
- llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
- lhs_llvm_type, "convolution_sum_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(lhs_element_type));
- llvm::Value* constant_zero =
- llvm::Constant::getNullValue(lhs_llvm_type);
- ir_builder_.CreateStore(constant_zero, sum_address);
-
- llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
- std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_spatial[i] =
- loops
- .AddLoop(0,
- rhs->shape().dimensions(
- dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
- ->GetIndVarValue();
- }
- llvm::Value* input_feature =
- loops
- .AddLoop(
- 0, lhs->shape().dimensions(dnums.input_feature_dimension()),
- "iz")
- ->GetIndVarValue();
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Calculate the spatial index in the input array, taking striding,
- // dilation and padding into account. An index in the padding will be
- // out of the bounds of the array.
- const auto calculate_input_index =
- [this](llvm::Value* output_index, llvm::Value* kernel_index,
- const WindowDimension& window_dim) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- output_index, ir_builder_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
- kernel_index,
- ir_builder_.getInt64(window_dim.window_dilation()));
- return ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
- ir_builder_.getInt64(window_dim.padding_low()));
- };
- std::vector<llvm::Value*> input_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] = calculate_input_index(
- output_spatial[i], kernel_spatial[i], window.dimensions(i));
- }
-
- // We need to check if 0 <= input dim < bound, as otherwise we are in
- // the padding so that we can skip the computation. That is equivalent
- // to input dim < bound as an *unsigned* comparison, since a negative
- // value will wrap to a large positive value. The input dim is dilated,
- // so we need to dilate the bound as well to match.
-
- // Also need to check that the input coordinates are not in one of the
- // holes created by base dilation.
- const auto not_in_hole = [&](llvm::Value* input_index,
- int64 base_dilation) {
- llvm::Value* remainder = ir_builder_.CreateSRem(
- input_index, ir_builder_.getInt64(base_dilation));
- return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
- };
-
- llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
- for (int i = 0; i < num_spatial_dims; ++i) {
- llvm::ConstantInt* input_bound =
- ir_builder_.getInt64(window_util::DilatedBound(
- lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
- window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound =
- ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
- llvm::Value* dim_not_in_hole = not_in_hole(
- input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok =
- ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
- }
-
- // Now we need to map the dilated base coordinates back to the actual
- // data indices on the lhs.
- const auto undilate = [&](llvm::Value* input_index,
- int64 base_dilation) {
- return ir_builder_.CreateSDiv(input_index,
- ir_builder_.getInt64(base_dilation));
- };
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] =
- undilate(input_spatial[i], window.dimensions(i).base_dilation());
- }
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- int num_dims = num_spatial_dims + 2;
- llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
- }
- input_index[dnums.input_feature_dimension()] = input_feature;
- input_index[dnums.input_batch_dimension()] = batch;
-
- llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
- llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(),
- num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_index[dnums.kernel_spatial_dimensions(i)] =
- window.dimensions(i).window_reversal()
- ? ir_builder_.CreateNSWSub(
- ir_builder_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
- : kernel_spatial[i];
- }
-
- kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
- kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
- llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
- llvm::Value* product = ir_builder_.CreateFMul(
- input_array.EmitReadArrayElement(input_index, &ir_builder_),
- kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
- llvm::Value* sum = ir_builder_.CreateFAdd(
- ir_builder_.CreateLoad(sum_address), product);
- ir_builder_.CreateStore(sum, sum_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(sum_address);
+ convolution, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForConvolution(
+ Cast<HloConvolutionInstruction>(convolution), index);
});
}
@@ -1152,11 +1138,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
}
// Args have been computed, make the call.
- llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
- llvm::Type* int32_type = ir_builder_.getInt32Ty();
- llvm::Type* int64_type = ir_builder_.getInt64Ty();
+ llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
+ llvm::Type* int32_type = b_.getInt32Ty();
+ llvm::Type* int64_type = b_.getInt64Ty();
llvm::FunctionType* fft_type = llvm::FunctionType::get(
- ir_builder_.getVoidTy(),
+ b_.getVoidTy(),
{int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
int64_type, int64_type, int64_type, int64_type},
/*isVarArg=*/false);
@@ -1173,16 +1159,15 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
const int fft_rank = fft_length.size();
- ir_builder_.CreateCall(
+ b_.CreateCall(
fft_func,
{GetExecutableRunOptionsArgument(),
- ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
- ir_builder_.CreateBitCast(operand_address, int8_ptr_type),
- ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank),
- ir_builder_.getInt64(input_batch),
- ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
- ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
- ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+ b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ b_.CreateBitCast(operand_address, int8_ptr_type),
+ b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank),
+ b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
return Status::OK();
}
@@ -1221,11 +1206,10 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
- ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
- /*SrcAlign=*/1,
- ShapeUtil::ByteSizeOf(operand_shape));
+ b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
- llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_);
+ llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
return Status::OK();
}
@@ -1258,47 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
- auto param_number = parameter->parameter_number();
- auto param_shape = parameter->shape();
-
- // We have to access the parameter at offset param_number in the params
- // array. The code generated here is equivalent to this C code:
- //
- // i8* param_address_untyped = params[param_number];
- // Param* param_address_typed = (Param*)param_address_untyped;
- //
- // Where Param is the actual element type of the underlying buffer (for
- // example, float for an XLA F32 element type).
- llvm::Value* params = compute_function_->parameters_arg();
- llvm::Value* param_address_offset =
- llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_);
- llvm::LoadInst* param_address_untyped =
- ir_builder_.CreateLoad(param_address_offset);
- param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
- .xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
- param_address_untyped->setMetadata(
- llvm::LLVMContext::MD_invariant_load,
- llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
- }
-
- llvm::Value* param_address_typed = ir_builder_.CreateBitCast(
- param_address_untyped, IrShapeType(param_shape)->getPointerTo());
- emitted_value_[parameter] = param_address_typed;
-
- if (!ShapeUtil::IsOpaque(param_shape)) {
- AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
- AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
- }
-
- VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
- return Status::OK();
+ return EmitTargetAddressForOp(parameter);
}
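HandleParameter now routes through EmitTargetAddressForOp; the deleted block shows the old direct indexing into the %params array, including the invariant-load annotation applied only in the entry computation. Attaching that metadata in isolation, mirroring the removed code:

// Sketch: mark a load as invariant so LLVM may freely hoist/CSE it. Sound
// only if nothing stores to the location while the function runs, which is
// why the removed code restricted it to the entry computation's parameters.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"

void MarkInvariantLoad(llvm::LoadInst* load) {
  load->setMetadata(llvm::LLVMContext::MD_invariant_load,
                    llvm::MDNode::get(load->getContext(), /*MDs=*/{}));
}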
// Returns true if the relative order of the unreduced dimensions stays the same
@@ -1396,62 +1340,61 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
return nullptr;
case HloOpcode::kAdd:
- return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
+ return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
llvm::Value* rhs) {
- return root_is_integral ? ir_builder->CreateAdd(lhs, rhs)
- : ir_builder->CreateFAdd(lhs, rhs);
+ return root_is_integral ? b->CreateAdd(lhs, rhs)
+ : b->CreateFAdd(lhs, rhs);
};
case HloOpcode::kMultiply:
- return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
+ return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
llvm::Value* rhs) {
- return root_is_integral ? ir_builder->CreateMul(lhs, rhs)
- : ir_builder->CreateFMul(lhs, rhs);
+ return root_is_integral ? b->CreateMul(lhs, rhs)
+ : b->CreateFMul(lhs, rhs);
};
case HloOpcode::kAnd:
- return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
- llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); };
+ return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
+ return b->CreateAnd(lhs, rhs);
+ };
case HloOpcode::kOr:
- return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
- llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); };
+ return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
+ return b->CreateOr(lhs, rhs);
+ };
case HloOpcode::kXor:
- return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
- llvm::Value* rhs) { return ir_builder->CreateXor(lhs, rhs); };
+ return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
+ return b->CreateXor(lhs, rhs);
+ };
case HloOpcode::kMaximum:
return [root_is_floating_point, root_is_signed](
- llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
- llvm::Value* rhs) {
+ llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
if (root_is_floating_point) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
- {lhs, rhs}, {lhs->getType()},
- ir_builder);
+ {lhs, rhs}, {lhs->getType()}, b);
}
- return ir_builder->CreateSelect(
- ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
- : llvm::ICmpInst::ICMP_UGE,
- lhs, rhs),
+ return b->CreateSelect(
+ b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
+ : llvm::ICmpInst::ICMP_UGE,
+ lhs, rhs),
lhs, rhs);
};
case HloOpcode::kMinimum:
return [root_is_floating_point, root_is_signed](
- llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
- llvm::Value* rhs) {
+ llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
if (root_is_floating_point) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
- {lhs, rhs}, {lhs->getType()},
- ir_builder);
+ {lhs, rhs}, {lhs->getType()}, b);
}
- return ir_builder->CreateSelect(
- ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
- : llvm::ICmpInst::ICMP_ULE,
- lhs, rhs),
+ return b->CreateSelect(
+ b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
+ : llvm::ICmpInst::ICMP_ULE,
+ lhs, rhs),
lhs, rhs);
};
}
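For integer element types the generators above lower max and min to an icmp plus select, with the predicate's signedness chosen from the element type; floating-point types go through the llvm.maxnum/llvm.minnum intrinsics instead. The integer path in scalar C++ terms:

// Sketch: compare+select is exactly a conditional move.
#include <cstdint>

int64_t SignedMax(int64_t lhs, int64_t rhs) {
  return lhs >= rhs ? lhs : rhs;  // ICMP_SGE feeding a select
}

uint64_t UnsignedMax(uint64_t lhs, uint64_t rhs) {
  return lhs >= rhs ? lhs : rhs;  // ICMP_UGE feeding a select
}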
@@ -1520,34 +1463,31 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
accumulator.reserve(accumulator_type.size());
for (auto accumulator_shard_type : accumulator_type) {
accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
- accumulator_shard_type, "accumulator", &ir_builder_, 0));
+ accumulator_shard_type, "accumulator", &b_, 0));
}
- llvm::Value* init_value_ssa =
- ir_builder_.CreateLoad(GetEmittedValueFor(init_value));
+ llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value));
for (llvm::Value* accumulator_shard : accumulator) {
llvm::Value* initial_value;
auto shard_type = accumulator_shard->getType()->getPointerElementType();
if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
- initial_value = ir_builder_.CreateVectorSplat(
- vector_type->getNumElements(), init_value_ssa);
+ initial_value =
+ b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa);
} else {
initial_value = init_value_ssa;
}
- ir_builder_.CreateAlignedStore(initial_value, accumulator_shard,
- element_alignment);
+ b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment);
}
llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
- &ir_builder_);
+ &b_);
llvm_ir::IrArray::Index reduced_dims_index =
reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
"reduction_dim");
- SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
+ SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
llvm_ir::IrArray::Index input_index = reduced_dims_index;
@@ -1560,38 +1500,34 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
}
CHECK(output_index.end() == it);
- llvm::Value* input_address = ir_builder_.CreateBitCast(
- arg_array.EmitArrayElementAddress(input_index, &ir_builder_),
- ir_builder_.getInt8PtrTy());
+ llvm::Value* input_address = b_.CreateBitCast(
+ arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
for (int i = 0; i < accumulator.size(); i++) {
auto input_address_typed =
- ir_builder_.CreateBitCast(input_address, accumulator[i]->getType());
+ b_.CreateBitCast(input_address, accumulator[i]->getType());
auto current_accumulator_value =
- ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment);
- auto addend =
- ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment);
+ b_.CreateAlignedLoad(accumulator[i], element_alignment);
+ auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment);
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
auto reduced_result =
- reduction_generator(&ir_builder_, current_accumulator_value, addend);
- ir_builder_.CreateAlignedStore(reduced_result, accumulator[i],
- element_alignment);
+ reduction_generator(&b_, current_accumulator_value, addend);
+ b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment);
if (i != (accumulator.size() - 1)) {
- input_address = ir_builder_.CreateConstInBoundsGEP1_32(
- reduced_result->getType(), input_address_typed, 1);
+ input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(),
+ input_address_typed, 1);
}
}
- SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
+ SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
ShardedVector result_ssa;
result_ssa.reserve(accumulator.size());
for (auto accumulator_shard : accumulator) {
result_ssa.push_back(
- ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment));
+ b_.CreateAlignedLoad(accumulator_shard, element_alignment));
}
return result_ssa;
}
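
What the rewritten inner loop computes, in miniature: the accumulator is split into shards, every shard is initialized by splatting the init value, and each input element is folded into its lane by the reduction generator. A C++ sketch under simplifying assumptions (flat input whose length is a multiple of the total lane count, addition as the reduction; all names are illustrative):

#include <vector>

// shards models the stack-allocated vector shards: one partial result per
// lane, all seeded with the splatted init value.
std::vector<float> ShardedReduce(const std::vector<float>& input, float init,
                                 int total_lanes) {
  std::vector<float> shards(total_lanes, init);
  for (size_t i = 0; i + total_lanes <= input.size(); i += total_lanes)
    for (int j = 0; j < total_lanes; ++j)
      shards[j] += input[i + j];  // reduction_generator(accumulator, addend)
  return shards;  // result_ssa: the shards loaded back at loop exit
}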
@@ -1600,17 +1536,17 @@ void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
const int alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
- auto store_address_typed = ir_builder_.CreateBitCast(
+ auto store_address_typed = b_.CreateBitCast(
store_address,
llvm::PointerType::getUnqual(value_to_store[i]->getType()));
- auto store_instruction = ir_builder_.CreateAlignedStore(
+ auto store_instruction = b_.CreateAlignedStore(
value_to_store[i], store_address_typed, alignment);
containing_array.AnnotateLoadStoreInstructionWithMetadata(
store_instruction);
if (i != (value_to_store.size() - 1)) {
- store_address = ir_builder_.CreateConstInBoundsGEP1_32(
+ store_address = b_.CreateConstInBoundsGEP1_32(
value_to_store[i]->getType(), store_address_typed, 1);
}
}
@@ -1676,8 +1612,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
// }
// }
- llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_);
- llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(),
+ llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
+ llvm_ir::IrArray::Index array_index(b_.getInt64Ty(),
reduce->shape().dimensions_size());
for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
--i) {
@@ -1696,7 +1632,7 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
if (llvm::BasicBlock* innermost_body_bb =
loop_nest.GetInnerLoopBodyBasicBlock()) {
- SetToFirstInsertPoint(innermost_body_bb, &ir_builder_);
+ SetToFirstInsertPoint(innermost_body_bb, &b_);
}
auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
@@ -1710,7 +1646,7 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
tensorflow::strings::Printf("dim.%lld", innermost_dimension));
array_index[innermost_dimension] = loop->GetIndVarValue();
- SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
ShardedVectorType vector_type = CreateShardedVectorType(
reduce->shape().element_type(), vectorization_factor);
@@ -1721,16 +1657,16 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
llvm::Value* output_address =
- target_array.EmitArrayElementAddress(array_index, &ir_builder_);
+ target_array.EmitArrayElementAddress(array_index, &b_);
EmitShardedVectorStore(output_address, accumulator, element_alignment,
target_array);
if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
- ir_builder_.SetInsertPoint(exit_terminator);
+ b_.SetInsertPoint(exit_terminator);
} else {
CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
- ir_builder_.SetInsertPoint(loop->GetExitBasicBlock());
+ b_.SetInsertPoint(loop->GetExitBasicBlock());
}
}
@@ -1740,8 +1676,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
if (innermost_dimension_size % vectorization_factor) {
// TODO(b/63775531): Consider using a scalar loop here to save on code size.
array_index[innermost_dimension] =
- ir_builder_.getInt64(innermost_dimension_size -
- (innermost_dimension_size % vectorization_factor));
+ b_.getInt64(innermost_dimension_size -
+ (innermost_dimension_size % vectorization_factor));
ShardedVectorType vector_type = CreateShardedVectorType(
reduce->shape().element_type(),
@@ -1753,19 +1689,77 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
llvm::Value* output_address =
- target_array.EmitArrayElementAddress(array_index, &ir_builder_);
+ target_array.EmitArrayElementAddress(array_index, &b_);
EmitShardedVectorStore(output_address, accumulator, element_alignment,
target_array);
}
if (outermost_loop_exit_block) {
- ir_builder_.SetInsertPoint(outermost_loop_exit_block);
+ b_.SetInsertPoint(outermost_loop_exit_block);
}
return true;
}
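
The overall shape of EmitVectorizedReduce is a main loop over the innermost dimension that strides by the vectorization factor, plus a shorter pass for the remainder when the dimension size is not a multiple of it (the patch handles that tail with a narrower sharded vector; per the TODO, a scalar loop is a possible alternative). A scalar C++ sketch of the split, assuming a plain sum:

#include <cstdint>

float ReduceInnermost(const float* in, int64_t n, int64_t vf) {
  float acc = 0.0f;
  int64_t main_extent = n - (n % vf);
  for (int64_t i = 0; i < main_extent; i += vf)  // vectorized body
    for (int64_t j = 0; j < vf; ++j) acc += in[i + j];
  for (int64_t i = main_extent; i < n; ++i)  // remainder, fewer than vf elems
    acc += in[i];
  return acc;
}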
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* arg = reduce->mutable_operand(0);
+ const HloInstruction* init_value = reduce->mutable_operand(1);
+ gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+
+ // Initialize an accumulator with init_value.
+ PrimitiveType accumulator_type = reduce->shape().element_type();
+ llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
+ &b_, MinimumAlignmentForPrimitiveType(accumulator_type));
+ llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
+ llvm::Value* load_init_value = b_.CreateLoad(init_value_addr);
+ b_.CreateStore(load_init_value, accumulator_addr);
+
+ // The enclosing loops go over all the target elements. Now we have to compute
+ // the actual target element. For this, we build a new loop nest to iterate
+ // over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
+ // are placed for each dimension in dimensions, and all the rest are nullptrs.
+ llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
+ const llvm_ir::IrArray::Index reduced_dims_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
+
+ // Build a full index for the input argument, using reduced_dims_index as the
+ // base. In reduced_dims_index only the reduction dimensions are filled in. We
+ // fill in the rest of the dimensions with induction Value*s taken from
+ // 'index' which iterates over the target array. See the high-level
+ // description in the XLA documentation for details.
+ llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
+ llvm_ir::IrArray::Index input_index = reduced_dims_index;
+ llvm_ir::IrArray::Index::const_iterator it = index.begin();
+
+ for (size_t i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+
+ // Apply the reduction function to the loaded value.
+ llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ "reduce_function");
+ b_.CreateStore(result, accumulator_addr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
+ return b_.CreateLoad(accumulator_addr);
+}
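
The index bookkeeping in this new helper deserves a spelled-out model: reduced_dims_index carries induction variables only in the reduction dimensions, and the remaining slots are filled, in order, from the target element index. A C++ sketch (std::optional stands in for the nullptr convention; names are illustrative):

#include <cstdint>
#include <optional>
#include <vector>

std::vector<int64_t> MergeIndices(
    const std::vector<std::optional<int64_t>>& reduced_dims_index,
    const std::vector<int64_t>& target_index) {
  std::vector<int64_t> input_index;
  size_t next = 0;
  for (const auto& dim : reduced_dims_index)
    input_index.push_back(dim ? *dim : target_index[next++]);
  // Every target dimension is consumed, mirroring CHECK(index.end() == it).
  return input_index;
}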
+
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
+ // TODO(b/112040122): Support variadic reduce.
+ if (!ShapeUtil::IsArray(reduce->shape())) {
+    return Unimplemented("Variadic reduce is not supported on CPU.");
+ }
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
@@ -1786,61 +1780,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
- return EmitTargetElementLoop(
- reduce, [this, reduce, arg, init_value, dimensions,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // Initialize an accumulator with init_value.
- PrimitiveType accumulator_type = reduce->shape().element_type();
- llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
- "accumulator", &ir_builder_,
- MinimumAlignmentForPrimitiveType(accumulator_type));
- llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
- ir_builder_.CreateStore(load_init_value, accumulator_addr);
-
- // The enclosing loops go over all the target elements. Now we have to
- // compute the actual target element. For this, we build a new loop nest
- // to iterate over all the reduction dimensions in the argument.
- // AddLoopsForShapeOnDimensions will return an Index where induction
- // Value*s are placed for each dimension in dimensions, and all the rest
- // are nullptrs.
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
- const llvm_ir::IrArray::Index reduced_dims_index =
- loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
- "reduction_dim");
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Build a full index for the input argument, using reduced_dims_index
- // as the base. In reduced_dims_index only the reduction dimensions are
- // filled in. We fill in the rest of the dimensions with induction
- // Value*s taken from 'index' which iterates over the target array.
- // See the high-level description in the XLA documentation for details.
- llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
- llvm_ir::IrArray::Index input_index = reduced_dims_index;
- llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
- for (size_t i = 0; i < input_index.size(); ++i) {
- if (input_index[i] == nullptr) {
- input_index[i] = *it++;
- }
- }
- CHECK(index.end() == it);
-
- // Apply the reduction function to the loaded value.
- llvm::Value* input_address =
- arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce->shape(),
- {accumulator_addr, input_address}, "reduce_function");
- ir_builder_.CreateStore(result, accumulator_addr);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_addr);
- });
+ return EmitTargetElementLoop(reduce,
+ [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduce(
+ Cast<HloReduceInstruction>(reduce), index);
+ });
}
Status IrEmitter::HandleSend(HloInstruction* send) {
@@ -1853,6 +1797,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
return Unimplemented("Send-done is not implemented on CPU.");
}
+Status IrEmitter::HandleScatter(HloInstruction*) {
+  return Unimplemented("Scatter is not implemented on CPU.");
+}
+
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
@@ -1942,7 +1890,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
llvm_ir::IrArray target_array = GetIrArrayFor(slice);
const int64 num_outer_loops = outer_dims.size();
- llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_);
+ llvm_ir::ForLoopNest loops(IrName(slice), &b_);
llvm_ir::IrArray::Index target_index =
loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
@@ -1951,21 +1899,21 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
// for the rest of the dimensions the copy writes to the full dimension.
std::replace(target_index.begin(), target_index.end(),
static_cast<llvm::Value*>(nullptr),
- static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
+ static_cast<llvm::Value*>(b_.getInt64(0)));
if (num_outer_loops > 0) {
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
}
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
/*shape=*/slice->shape(), /*starts=*/slice->slice_starts(),
- /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_);
+ /*strides=*/slice->slice_strides(), /*builder=*/&b_);
- llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress(
- target_index, &ir_builder_, "slice.dest");
- llvm::Value* memcpy_source = source_array.EmitArrayElementAddress(
- source_index, &ir_builder_, "slice.source");
+ llvm::Value* memcpy_dest =
+ target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
+ llvm::Value* memcpy_source =
+ source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
const int64 memcpy_elements =
primitive_elements_per_logical_element * memcpy_logical_elements;
@@ -1982,7 +1930,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
}
if (num_outer_loops > 0) {
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
}
return Status::OK();
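
The fast path in HandleSlice reduces the slice to a loop over the outer dimensions with one memcpy of the contiguous inner run per iteration. A C++ model of the per-run copy (a hypothetical 1-D strided case; the real code derives run lengths and strides from the layout):

#include <cstdint>
#include <cstring>

// Copies `runs` contiguous runs of `run_elems` floats; the source advances
// by `stride_elems` per run, the destination is packed.
void StridedSliceCopy(float* dst, const float* src, int64_t runs,
                      int64_t run_elems, int64_t stride_elems) {
  for (int64_t i = 0; i < runs; ++i)
    std::memcpy(dst + i * run_elems, src + i * stride_elems,
                run_elems * sizeof(float));
}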
@@ -2008,7 +1956,7 @@ Status IrEmitter::HandleDynamicUpdateSlice(
auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
return llvm_ir::EmitDynamicUpdateSliceInPlace(
operands, GetIrArrayFor(dynamic_update_slice),
- IrName(dynamic_update_slice, "in_place"), &ir_builder_);
+ IrName(dynamic_update_slice, "in_place"), &b_);
}
return DefaultAction(dynamic_update_slice);
}
@@ -2042,43 +1990,41 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
[this, pad](const llvm_ir::IrArray::Index& target_index) {
const HloInstruction* padding_value = pad->operand(1);
llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
- return ir_builder_.CreateLoad(padding_value_addr);
+ return b_.CreateLoad(padding_value_addr);
}));
// Create a loop to iterate over the operand elements and update the output
// locations where the operand elements should be stored.
- llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &ir_builder_);
+ llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
const HloInstruction* operand = pad->operand(0);
const llvm_ir::IrArray::Index operand_index =
loops.AddLoopsForShape(operand->shape(), "operand");
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
// Load an element from the operand.
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
- operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
+ operand_array.EmitReadArrayElement(operand_index, &b_);
// Compute the output index the operand element should be assigned to.
// output_index := edge_padding_low + operand_index * (interior_padding + 1)
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
- llvm::Value* offset = ir_builder_.CreateMul(
+ llvm::Value* offset = b_.CreateMul(
operand_index[i],
- ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() +
- 1));
- llvm::Value* index = ir_builder_.CreateAdd(
- offset,
- ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low()));
+ b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
+ llvm::Value* index = b_.CreateAdd(
+ offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
output_index.push_back(index);
}
// Store the operand element to the computed output location.
llvm_ir::IrArray output_array(GetIrArrayFor(pad));
- output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_);
+ output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
return Status::OK();
}
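
As a concrete instance of the index mapping in the comment above: with edge_padding_low = 2 and interior_padding = 1, operand index 3 lands at output index 2 + 3 * (1 + 1) = 8. The formula as a standalone C++ helper (illustrative only):

#include <cstdint>

// output_index := edge_padding_low + operand_index * (interior_padding + 1)
int64_t PadOutputIndex(int64_t operand_index, int64_t edge_padding_low,
                       int64_t interior_padding) {
  return edge_padding_low + operand_index * (interior_padding + 1);
}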
@@ -2100,8 +2046,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
// Delegate to common implementation of fused in-place dynamic-update-slice.
auto operands = GetIrArraysForOperandsOf(fusion);
return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
- fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
- &ir_builder_);
+ fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_);
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
VLOG(3) << "HandleFusion kLoop";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
@@ -2136,7 +2081,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
*dot, target_array, lhs_array, rhs_array, &addend_array,
- GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
+ GetExecutableRunOptionsArgument(), &b_, hlo_module_config_,
target_machine_features_));
return Status::OK();
} else {
@@ -2148,18 +2093,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : call->operands()) {
- parameter_addresses.push_back(GetEmittedValueFor(operand));
- }
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
- parameter_addresses, &ir_builder_, computation->name(),
+ {}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@@ -2167,11 +2107,10 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloInstruction* root = computation->root_instruction();
TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
- call_args, root->shape(), root->outer_dimension_partitions(),
- &ir_builder_, call_ir_function, computation->name()));
+ call_args, root->shape(), root->outer_dimension_partitions(), &b_,
+ call_ir_function, computation->name()));
} else {
- EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
- emitted_value_[call], computation->name());
+ EmitGlobalCall(*computation, computation->name());
}
return Status::OK();
@@ -2180,33 +2119,31 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
- llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
+ llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- i8_ptr_type, ir_builder_.getInt32(operands.size()),
- "cc_operands_alloca", &ir_builder_);
+ i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
- ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
- llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP(
- operands_alloca, {ir_builder_.getInt64(i)});
- ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ llvm::Value* slot_in_operands_alloca =
+ b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)});
+ b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
AsStringRef(custom_call_target),
llvm::FunctionType::get(
- /*Result=*/ir_builder_.getVoidTy(),
+ /*Result=*/b_.getVoidTy(),
/*Params=*/{i8_ptr_type, operands_alloca->getType()},
/*isVarArg=*/false)));
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
- auto* output_address_arg = ir_builder_.CreatePointerCast(
- GetEmittedValueFor(custom_call), i8_ptr_type);
+ auto* output_address_arg =
+ b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
- ir_builder_.CreateCall(custom_call_ir_function,
- {output_address_arg, operands_alloca});
+ b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca});
return Status::OK();
}
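
The declaration built above pins down the calling convention a custom-call target must honor: void return, one untyped output pointer, and an array of untyped operand pointers. A sketch of a conforming target (MyCustomCall and the float payload are assumptions for illustration, not part of XLA):

// Matches FunctionType::get(void, {i8*, i8**}) from the hunk above.
extern "C" void MyCustomCall(void* output, void** operands) {
  const float* x = static_cast<const float*>(operands[0]);
  static_cast<float*>(output)[0] = 2.0f * x[0];
}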
@@ -2254,12 +2191,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
- // The called computation should have been emitted previously.
- llvm::Function* condition_ir_function =
- FindOrDie(emitted_functions_, condition);
- llvm::Function* body_ir_function =
- FindOrDie(emitted_functions_, xla_while->while_body());
-
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@@ -2271,17 +2202,15 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
compute_function_->function());
- ir_builder_.CreateBr(header_bb);
- ir_builder_.SetInsertPoint(header_bb);
+ b_.CreateBr(header_bb);
+ b_.SetInsertPoint(header_bb);
  // Calls the condition function to determine whether to proceed with the
  // body. It must return a bool; the PRED result is written to the condition
  // computation's return buffer, which is loaded and tested below.
- llvm::Value* while_result = GetEmittedValueFor(xla_while);
- llvm::Value* while_condition = EmitElementFunctionCall(
- condition_ir_function, condition->root_instruction()->shape(),
- {while_result}, IrName(xla_while, "cond"));
- llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
- while_condition,
+ EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
+ llvm::Value* while_predicate = b_.CreateICmpNE(
+ b_.CreateLoad(
+ GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2290,20 +2219,20 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
- ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb);
+ b_.CreateCondBr(while_predicate, body_bb, exit_bb);
// Calls the body function from the body block.
- ir_builder_.SetInsertPoint(body_bb);
+ b_.SetInsertPoint(body_bb);
// Calls the body function.
- EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
- IrName(xla_while, "body"));
+ EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
+
// Finishes with a branch back to the header.
- ir_builder_.CreateBr(header_bb);
+ b_.CreateBr(header_bb);
// Adds the exit block to the function and sets the insert point there.
compute_function_->function()->getBasicBlockList().push_back(exit_bb);
- ir_builder_.SetInsertPoint(exit_bb);
+ b_.SetInsertPoint(exit_bb);
return Status::OK();
}
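
The basic-block structure HandleWhile emits corresponds to the following control flow, with condition and body now invoked via EmitGlobalCall and the PRED result read back from the condition's return buffer. A C++ caricature with stand-in stubs (nothing here is XLA API):

#include <cstdint>

uint8_t pred_buffer = 0;                 // the condition's return buffer
void CallCondition() { /* writes pred_buffer */ }
void CallBody() { /* updates the while state in place */ }

void WhileSkeleton() {
  for (;;) {                             // header block
    CallCondition();                     // EmitGlobalCall(..., "cond")
    if (pred_buffer == 0) break;         // CreateICmpNE + CondBr to exit
    CallBody();                          // body block, then Br to header
  }                                      // exit block
}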
@@ -2345,21 +2274,21 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
output_min2maj.end());
- llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
- llvm::Type* i8_type = ir_builder_.getInt8Ty();
+ llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
+ llvm::Type* i8_type = b_.getInt8Ty();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
- llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_);
+ llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
llvm_ir::IrArray::Index outer_dims_index =
loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
std::replace(outer_dims_index.begin(), outer_dims_index.end(),
static_cast<llvm::Value*>(nullptr),
- static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
+ static_cast<llvm::Value*>(b_.getInt64(0)));
if (!outer_dims.empty()) {
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
}
PrimitiveType primitive_type = output_shape.element_type();
@@ -2368,10 +2297,10 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
- llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
- target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
- "target_region"),
- i8_ptr_type);
+ llvm::Value* target_region_begin =
+ b_.CreateBitCast(target_array.EmitArrayElementAddress(
+ outer_dims_index, &b_, "target_region"),
+ i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
@@ -2385,14 +2314,13 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
- llvm::Value* copy_source_address = ir_builder_.CreateBitCast(
- source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
- "src_addr"),
+ llvm::Value* copy_source_address = b_.CreateBitCast(
+ source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"),
i8_ptr_type);
- llvm::Value* copy_target_address = ir_builder_.CreateGEP(
- i8_type, target_region_begin,
- ir_builder_.getInt64(byte_offset_into_target_region));
+ llvm::Value* copy_target_address =
+ b_.CreateGEP(i8_type, target_region_begin,
+ b_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
@@ -2405,7 +2333,7 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
}
if (!outer_dims.empty()) {
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
}
return true;
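
Per outer index, each operand contributes one contiguous region copied at a running byte offset into the target, which is the whole trick behind the fast concatenate. One outer-loop iteration, modeled in C++ (byte counts in the real code come from the primitive size times the inner-dimension product):

#include <cstddef>
#include <cstring>
#include <utility>
#include <vector>

void ConcatRegions(char* target,
                   const std::vector<std::pair<const char*, size_t>>& srcs) {
  size_t offset = 0;  // byte_offset_into_target_region
  for (const auto& [src, bytes] : srcs) {
    std::memcpy(target + offset, src, bytes);  // EmitTransferElements
    offset += bytes;
  }
}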
@@ -2424,16 +2352,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
- auto* load_instruction = ir_builder_.CreateAlignedLoad(
- ir_builder_.CreateBitCast(source, primitive_ptr_type),
- element_alignment);
+ auto* load_instruction = b_.CreateAlignedLoad(
+ b_.CreateBitCast(source, primitive_ptr_type), element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
- auto* store_instruction = ir_builder_.CreateAlignedStore(
- load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type),
+ auto* store_instruction = b_.CreateAlignedStore(
+ load_instruction, b_.CreateBitCast(target, primitive_ptr_type),
element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
- auto* memcpy_instruction = ir_builder_.CreateMemCpy(
+ auto* memcpy_instruction = b_.CreateMemCpy(
target, /*DstAlign=*/element_alignment, source,
/*SrcAlign=*/element_alignment, element_count * primitive_type_size);
@@ -2467,8 +2394,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@@ -2490,37 +2415,31 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
- llvm::Function* true_function =
- FindOrDie(emitted_functions_, true_computation);
- llvm::Function* false_function =
- FindOrDie(emitted_functions_, false_computation);
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
- llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
- llvm::LoadInst* pred_value = ir_builder_.CreateLoad(
+ llvm::LoadInst* pred_value = b_.CreateLoad(
GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = ir_builder_.CreateICmpNE(
+ llvm::Value* pred_cond = b_.CreateICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_);
+ llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
- EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
- conditional_result, IrName(conditional, "_true"));
+ SetToFirstInsertPoint(if_data.true_block, &b_);
+ EmitGlobalCall(*conditional->true_computation(),
+ IrName(conditional, "_true"));
- SetToFirstInsertPoint(if_data.false_block, &ir_builder_);
- EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
- conditional_result, IrName(conditional, "_false"));
+ SetToFirstInsertPoint(if_data.false_block, &b_);
+ EmitGlobalCall(*conditional->false_computation(),
+ IrName(conditional, "_false"));
- SetToFirstInsertPoint(if_data.after_block, &ir_builder_);
+ SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
}
@@ -2531,6 +2450,28 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
return Status::OK();
}
+Status IrEmitter::HandleIota(HloInstruction* iota) {
+ // TODO(b/64798317): implement iota on CPU.
+ return Unimplemented("Iota is not implemented on CPU.");
+}
+
+Status IrEmitter::HandleRng(HloInstruction* rng) {
+ ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
+ for (const HloInstruction* operand : rng->operands()) {
+ operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
+ };
+ }
+
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
+ TF_RETURN_IF_ERROR(EmitTargetElementLoop(
+ rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator)));
+
+ llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_);
+
+ return Status::OK();
+}
+
Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding
@@ -2548,7 +2489,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) {
auto record_complete_computation = [&](llvm::Value* prof_counter) {
if (prof_counter) {
- profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter);
+ profiling_state_.RecordCompleteComputation(&b_, prof_counter);
}
};
@@ -2570,54 +2511,51 @@ llvm::Value* IrEmitter::GetProfileCounterCommon(
int64 prof_counter_idx = it->second;
string counter_name = IrName("prof_counter", hlo.name());
- return ir_builder_.CreateGEP(GetProfileCountersArgument(),
- ir_builder_.getInt64(prof_counter_idx),
- AsStringRef(counter_name));
+ return b_.CreateGEP(GetProfileCountersArgument(),
+ b_.getInt64(prof_counter_idx), AsStringRef(counter_name));
}
-void IrEmitter::ProfilingState::UpdateProfileCounter(
- llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter,
- llvm::Value* cycle_end, llvm::Value* cycle_start) {
- auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start);
+void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
+ llvm::Value* prof_counter,
+ llvm::Value* cycle_end,
+ llvm::Value* cycle_start) {
+ auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
llvm::LoadInst* old_cycle_count =
- ir_builder->CreateLoad(prof_counter, "old_cycle_count");
+ b->CreateLoad(prof_counter, "old_cycle_count");
auto* new_cycle_count =
- ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
- ir_builder->CreateStore(new_cycle_count, prof_counter);
+ b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
+ b->CreateStore(new_cycle_count, prof_counter);
}
-llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(
- llvm::IRBuilder<>* ir_builder) {
- llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
+ llvm::Module* module = b->GetInsertBlock()->getModule();
  if (!use_rdtscp_) {
llvm::Function* func_llvm_readcyclecounter =
llvm::Intrinsic::getDeclaration(module,
llvm::Intrinsic::readcyclecounter);
- return ir_builder->CreateCall(func_llvm_readcyclecounter);
+ return b->CreateCall(func_llvm_readcyclecounter);
}
llvm::Function* func_llvm_x86_rdtscp =
llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
if (!aux_i8ptr_) {
- llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry(
- ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder);
- aux_i8ptr_ =
- ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy());
+ llvm::AllocaInst* rdtscp_aux =
+ llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b);
+ aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy());
}
- llvm::ConstantInt* alloca_size = ir_builder->getInt64(4);
+ llvm::ConstantInt* alloca_size = b->getInt64(4);
llvm::Function* func_llvm_lifetime_start =
llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
- ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
- llvm::Value* rdtscp_call =
- ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
+ b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
+ llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
llvm::Function* func_llvm_lifetime_end =
llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
- ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
+ b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
return rdtscp_call;
}
-void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder,
+void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
HloInstruction* hlo) {
- auto* cycle_start = ReadCycleCounter(ir_builder);
+ auto* cycle_start = ReadCycleCounter(b);
cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start")));
cycle_starts_[hlo] = cycle_start;
if (first_read_cycle_start_ == nullptr) {
@@ -2625,20 +2563,20 @@ void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder,
}
}
-void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder,
+void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
HloInstruction* hlo,
llvm::Value* prof_counter) {
- auto* cycle_end = ReadCycleCounter(ir_builder);
+ auto* cycle_end = ReadCycleCounter(b);
cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end")));
auto* cycle_start = cycle_starts_[hlo];
- UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start);
+ UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
last_read_cycle_end_ = cycle_end;
}
void IrEmitter::ProfilingState::RecordCompleteComputation(
- llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) {
+ llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
if (last_read_cycle_end_ && first_read_cycle_start_) {
- UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_,
+ UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
first_read_cycle_start_);
}
}
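
The profiling scheme in miniature: read the cycle counter at an instruction's start and end and accumulate the delta into that instruction's counter slot; RecordCompleteComputation does the same with the first start and the last end. As plain C++ (the counter slot is an assumption for illustration):

#include <cstdint>

// The CreateSub + CreateAdd + CreateStore sequence from UpdateProfileCounter.
void UpdateCounter(uint64_t* counter, uint64_t cycle_start,
                   uint64_t cycle_end) {
  *counter += cycle_end - cycle_start;
}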
@@ -2646,14 +2584,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation(
Status IrEmitter::Preprocess(HloInstruction* hlo) {
VLOG(3) << "Visiting: " << hlo->ToString();
if (instruction_to_profile_idx_.count(hlo)) {
- profiling_state_.RecordCycleStart(&ir_builder_, hlo);
+ profiling_state_.RecordCycleStart(&b_, hlo);
}
return Status::OK();
}
Status IrEmitter::Postprocess(HloInstruction* hlo) {
if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
- profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter);
+ profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
}
return Status::OK();
}
@@ -2700,42 +2638,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
- llvm::Type* element_type = IrShapeType(target_shape);
- // The alignment and number of bytes within the temporary buffer is determined
- // by the maximal shape as determined by buffer assignment.
- const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
- if (allocation.is_thread_local()) {
+ const BufferAllocation& allocation = *slice.allocation();
+ llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
+ if (slice == computation_root_allocation_) {
+ llvm::Argument* retval = compute_function_->result_arg();
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttrs(attr_builder);
+ return retval;
+ }
+
+ auto param_it =
+ computation_parameter_allocations_.find(slice.allocation()->index());
+ if (param_it != computation_parameter_allocations_.end()) {
+ int64 param_number = param_it->second;
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Value* params = compute_function_->parameters_arg();
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
+ llvm::LoadInst* param_address_untyped =
+ b_.CreateLoad(param_address_offset);
+
+ if (!ShapeUtil::IsOpaque(target_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped,
+ target_shape);
+ }
+ return param_address_untyped;
+ }
+
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
- llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{
- ir_builder_.GetInsertBlock()->getParent(), slice}];
- if (tempbuf_address == nullptr) {
- tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ std::pair<llvm::Function*, BufferAllocation::Slice> key = {
+ compute_function_->function(), slice};
+ auto buf_it = thread_local_buffers_.find(key);
+ if (buf_it == thread_local_buffers_.end()) {
+ llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
- tensorflow::strings::StrCat("thread_local", slice.ToString()),
- &ir_builder_, MinimumAlignmentForShape(target_shape));
+ tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
+ MinimumAlignmentForShape(target_shape));
+ auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
+ CHECK(it_inserted_pair.second);
+ buf_it = it_inserted_pair.first;
}
- return ir_builder_.CreateBitCast(tempbuf_address,
- element_type->getPointerTo());
- }
+ return buf_it->second;
+ }();
+ return b_.CreateBitCast(tempbuf_address,
+ IrShapeType(target_shape)->getPointerTo());
+}
+llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
- GetTempBuffersArgument(), slice.index(), &ir_builder_);
- llvm::LoadInst* tempbuf_address_base =
- ir_builder_.CreateLoad(tempbuf_address_ptr);
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
+ GetTempBuffersArgument(), slice.index(), &b_);
+ llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
+ if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@@ -2746,90 +2718,29 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
if (slice.offset() > 0) {
// Adjust the address to account for the slice offset.
- tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP(
- tempbuf_address_base, ir_builder_.getInt64(slice.offset()));
+ tempbuf_address_untyped =
+ b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
- return ir_builder_.CreateBitCast(tempbuf_address_untyped,
- element_type->getPointerTo());
-}
-
-// Emits a function call returning a single array element. Allocates space
-// for a single element_type value, and loads it after call.
-llvm::Value* IrEmitter::EmitElementFunctionCall(
- llvm::Function* function, const Shape& return_shape,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* return_value_buffer = EmitArrayFunctionCall(
- function, return_shape, 1, parameter_addresses, name);
- return ir_builder_.CreateLoad(
- return_value_buffer,
- AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
-}
-
-// Emits a core function call based on the following pseudo-code.
-//
-// char** parameter_addresses_buffer =
-// allocate buffer with a pointer for each parameter to the function
-// for each parameter index, i.e. for i = 0, ..., #parameters:
-// parameter_addresses_buffer[i] = parameter_addresses[i]
-// call function(return_value_buffer,
-// parameter_addresses_buffer,
-// temps)
-// return return_value_buffer -- address of the return value.
-void IrEmitter::EmitArrayFunctionCallInto(
- llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
- ir_builder_.CreateCall(
- function, GetArrayFunctionCallArguments(
- parameter_addresses, &ir_builder_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ return b_.CreateBitCast(tempbuf_address_untyped,
+ IrShapeType(target_shape)->getPointerTo());
}
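
Stripped of metadata bookkeeping, EmitGlobalTempBufferPointer is plain pointer arithmetic: index the temp-buffer table by the allocation, then step into the allocation by the slice's byte offset. As C++ (types are illustrative):

#include <cstdint>

char* GlobalTempBufferPointer(char** temp_buffers, int64_t allocation_index,
                              int64_t slice_offset_bytes) {
  char* base = temp_buffers[allocation_index];  // the (invariant) table load
  return base + slice_offset_bytes;             // the InBoundsGEP adjustment
}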
-llvm::Value* IrEmitter::EmitArrayFunctionCall(
- llvm::Function* function, const Shape& return_shape, int64 element_count,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* elements =
- llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count);
- PrimitiveType return_type = return_shape.element_type();
- llvm::Value* return_value_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
- tensorflow::strings::StrCat(name, "_return_value_address"),
- &ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
- EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
- name);
- return return_value_buffer;
+llvm::Value* IrEmitter::EmitTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ if (slice.allocation()->is_thread_local()) {
+ return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ } else if (slice.allocation()->is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
+ } else {
+ return EmitGlobalTempBufferPointer(slice, target_shape);
+ }
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
- llvm::Value* addr;
const Shape& target_shape = op->shape();
- if (op == op->parent()->root_instruction()) {
- // For the root node, we write directly to the output buffer of the
- // function.
- llvm::Argument* retval = compute_function_->result_arg();
- if ((ShapeUtil::IsArray(target_shape) &&
- !ShapeUtil::IsZeroElementArray(target_shape)) ||
- (ShapeUtil::IsTuple(target_shape) &&
- !ShapeUtil::IsEmptyTuple(target_shape))) {
- llvm::AttrBuilder attr_builder;
- attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
- attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
- retval->addAttrs(attr_builder);
- }
- addr = ir_builder_.CreateBitCast(retval,
- IrShapeType(target_shape)->getPointerTo());
- } else {
- // For other nodes, we need the temporary buffer allocated for this node to
- // write the result into.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- addr = EmitTempBufferPointer(slice, target_shape);
- }
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2864,14 +2775,14 @@ Status IrEmitter::EmitTargetElementLoop(
llvm_ir::IrArray(op_target_address, element_shape));
}
TF_RETURN_IF_ERROR(
- llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_)
+ llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
.EmitLoop(IrName(target_op)));
std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
- llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_);
+ llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_);
} else {
if (ShouldEmitParallelLoopFor(*target_op)) {
@@ -2880,11 +2791,11 @@ Status IrEmitter::EmitTargetElementLoop(
compute_function_->GetDynamicLoopBounds();
// Emit parallel loop with dynamic loop bounds for most-major dimensions.
TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
- &dynamic_loop_bounds, &ir_builder_)
+ &dynamic_loop_bounds, &b_)
.EmitLoop(IrName(target_op)));
} else {
TF_RETURN_IF_ERROR(
- llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
+ llvm_ir::LoopEmitter(element_generator, target_array, &b_)
.EmitLoop(IrName(target_op)));
}
}
@@ -2897,8 +2808,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
- /*SrcAlign=*/1, source_size);
+ b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
@@ -2926,7 +2837,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_);
+ return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
};
}
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
@@ -2934,20 +2845,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
-StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
- PrimitiveType return_type, HloComputation* computation,
- const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
- llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> argument_addrs;
- for (auto argument : arguments) {
- llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- argument->getType(), "arg_addr", &ir_builder_);
- ir_builder_.CreateStore(argument, argument_addr);
- argument_addrs.push_back(argument_addr);
+llvm::Value* IrEmitter::EmitThreadLocalCall(
+ const HloComputation& callee,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ tensorflow::StringPiece name) {
+ const Shape& return_shape = callee.root_instruction()->shape();
+
+  // The return value must be a scalar: we allocate its buffer on the stack.
+  // Lifting this restriction to allow "small" arrays should be easy; allowing
+  // larger arrays is difficult for the same stack-allocation reason.
+ CHECK(ShapeUtil::IsScalar(return_shape));
+
+ PrimitiveType return_type = return_shape.element_type();
+
+ std::vector<llvm::Value*> parameter_addrs;
+ for (llvm::Value* parameter : parameters) {
+ CHECK(!parameter->getType()->isPointerTy());
+ llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ parameter->getType(), "arg_addr", &b_);
+ b_.CreateStore(parameter, parameter_addr);
+ parameter_addrs.push_back(parameter_addr);
+ }
+
+ llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(return_type, module_),
+ tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ MinimumAlignmentForPrimitiveType(return_type));
+
+ b_.CreateCall(
+ FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+
+ return b_.CreateLoad(return_value_buffer);
+}
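
The thread-local call protocol, condensed: spill each scalar argument to a stack slot, allocate a scalar return slot, pass the addresses, and load the scalar back after the call. A C++ caricature of a two-argument case, with the XLA-specific arguments (run options, null temp-buffer table, profile counters) elided:

float CallThreadLocal(void (*callee)(float* retval, float** params),
                      float a, float b) {
  float a_slot = a, b_slot = b;  // the arg_addr allocas at function entry
  float ret = 0.0f;              // the "_retval_addr" alloca
  float* params[] = {&a_slot, &b_slot};
  callee(&ret, params);
  return ret;                    // CreateLoad(return_value_buffer)
}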
+
+void IrEmitter::EmitGlobalCall(const HloComputation& callee,
+ tensorflow::StringPiece name) {
+ b_.CreateCall(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+}
+
+llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
+ const HloComputation& callee) {
+ const HloInstruction* root_inst = callee.root_instruction();
+ if (root_inst->opcode() == HloOpcode::kOutfeed) {
+ return llvm::Constant::getNullValue(b_.getInt8PtrTy());
}
- return EmitElementFunctionCall(llvm_function,
- ShapeUtil::MakeShape(return_type, {}),
- argument_addrs, name);
+
+ const BufferAllocation::Slice root_buffer =
+ assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
+ return EmitTempBufferPointer(root_buffer, root_inst->shape());
}
+
} // namespace cpu
} // namespace xla