aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-07-31 17:21:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 17:26:19 -0700
commit182a00ee781017932443bacb475af7acc4a56d5a (patch)
treeec24c5cc7e424642bf6772b4d46f7a08ac31dc88 /tensorflow/compiler/xla/service/cpu/ir_emitter.cc
parent64f191cdc0121bbcb322c3b11b160d638c2f4af9 (diff)
Automated rollback of commit fba2d773f45f10882aa475ac75cbf9884995d626
PiperOrigin-RevId: 206855848
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc410
1 file changed, 230 insertions, 180 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 60f9cd1121..a6d8551841 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -116,19 +116,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
- if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
- TF_ASSIGN_OR_RETURN(
- computation_root_allocation_,
- assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
- }
-
- for (const HloInstruction* param : computation->parameter_instructions()) {
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
- assignment_.GetUniqueTopLevelSlice(param));
- computation_parameter_allocations_[param_slice.allocation()->index()] =
- param->parameter_number();
- }
-
InitializeIrFunction(function_name);
// The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
// readcyclecounter if it is unavailable.
@@ -145,8 +132,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
- computation_root_allocation_ = BufferAllocation::Slice();
- computation_parameter_allocations_.clear();
return ir_function;
}
@@ -499,11 +484,23 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-llvm::Value* IrEmitter::EmitElementalMap(
- const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name) {
- return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
+ llvm::Function* mapped_ir_function =
+ FindOrDie(emitted_functions_, map->to_apply());
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : map->operands()) {
+ const llvm_ir::IrArray& array = GetIrArrayFor(operand);
+ parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_));
+ }
+ return EmitElementFunctionCall(mapped_ir_function, map->shape(),
+ parameter_addresses, "map_function");
+}
+
+Status IrEmitter::HandleMap(HloInstruction* map) {
+ return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
+ });
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
@@ -511,6 +508,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
+ HloComputation* function = reduce_window->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// We fold inputs into the accumulator and initialize it to
// the initial value on the reduce_window.
@@ -563,10 +563,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
// We are not in the padding, so carry out the computation.
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
- llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
+ llvm::Value* input_value_address =
+ input_array.EmitArrayElementAddress(input_index, &b_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce_window->shape(),
+ {accumulator_address, input_value_address}, "reducer_function");
b_.CreateStore(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -622,6 +623,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
+ // The select and scatter computations should have been emitted previously.
+ llvm::Function* select_function =
+ FindOrDie(emitted_functions_, select_and_scatter->select());
+ llvm::Function* scatter_function =
+ FindOrDie(emitted_functions_, select_and_scatter->scatter());
+
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@@ -726,12 +733,11 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
+ const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
- llvm::Value* result = EmitThreadLocalCall(
- *select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
+ llvm::Value* result = EmitElementFunctionCall(
+ select_function, output_shape, {selected_value_address, operand_address},
"select_function");
// If the 'select' function returns false, update the selected value and the
@@ -758,14 +764,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
- llvm::Value* source_value =
- source_array.EmitReadArrayElement(source_index, &b_);
+ llvm::Value* source_value_address =
+ source_array.EmitArrayElementAddress(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
- llvm::Value* output_value =
- output_array.EmitReadArrayElement(selected_index, &b_);
- llvm::Value* scatter_value =
- EmitThreadLocalCall(*select_and_scatter->scatter(),
- {output_value, source_value}, "scatter_function");
+ llvm::Value* output_value_address =
+ output_array.EmitArrayElementAddress(selected_index, &b_);
+ llvm::Value* scatter_value = EmitElementFunctionCall(
+ scatter_function, source->shape(),
+ {output_value_address, source_value_address}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1242,7 +1248,46 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
- return EmitTargetAddressForOp(parameter);
+ auto param_number = parameter->parameter_number();
+ auto param_shape = parameter->shape();
+
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Value* params = compute_function_->parameters_arg();
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
+ llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset);
+ param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
+ if (is_top_level_computation_ &&
+ hlo_module_config_.debug_options()
+ .xla_llvm_enable_invariant_load_metadata()) {
+ // In the entry computation the parameter slots in the %params argument are
+ // invariant through program execution. In computations that are called
+ // from the entry computation (via kWhile, kCall and kConditional) the
+ // parameter slots are *not* invariant since they're written to by their
+ // callers.
+ param_address_untyped->setMetadata(
+ llvm::LLVMContext::MD_invariant_load,
+ llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
+ }
+
+ llvm::Value* param_address_typed = b_.CreateBitCast(
+ param_address_untyped, IrShapeType(param_shape)->getPointerTo());
+ emitted_value_[parameter] = param_address_typed;
+
+ if (!ShapeUtil::IsOpaque(param_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
+ }
+
+ VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
+ return Status::OK();
}
// Returns true if the relative order of the unreduced dimensions stays the same
@@ -1706,6 +1751,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ HloComputation* function = reduce->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1745,9 +1793,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
- llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
- llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ llvm::Value* input_address =
+ arg_array.EmitArrayElementAddress(input_index, &b_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce->shape(), {accumulator_addr, input_address},
"reduce_function");
b_.CreateStore(result, accumulator_addr);
@@ -2085,13 +2134,18 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : call->operands()) {
+ parameter_addresses.push_back(GetEmittedValueFor(operand));
+ }
+
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
- {}, &b_, computation->name(),
+ parameter_addresses, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@@ -2102,7 +2156,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
call_args, root->shape(), root->outer_dimension_partitions(), &b_,
call_ir_function, computation->name()));
} else {
- EmitGlobalCall(*computation, computation->name());
+ EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
+ emitted_value_[call], computation->name());
}
return Status::OK();
@@ -2183,6 +2238,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
+ // The called computation should have been emitted previously.
+ llvm::Function* condition_ir_function =
+ FindOrDie(emitted_functions_, condition);
+ llvm::Function* body_ir_function =
+ FindOrDie(emitted_functions_, xla_while->while_body());
+
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@@ -2199,10 +2260,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
- EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
+ llvm::Value* while_result = GetEmittedValueFor(xla_while);
+ llvm::Value* while_condition = EmitElementFunctionCall(
+ condition_ir_function, condition->root_instruction()->shape(),
+ {while_result}, IrName(xla_while, "cond"));
llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ while_condition,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2217,8 +2280,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
b_.SetInsertPoint(body_bb);
// Calls the body function.
- EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
-
+ EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
+ IrName(xla_while, "body"));
// Finishes with a branch back to the header.
b_.CreateBr(header_bb);
@@ -2386,6 +2449,8 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
+ auto true_arg = conditional->operand(1);
+ auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@@ -2407,7 +2472,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
+ llvm::Function* true_function =
+ FindOrDie(emitted_functions_, true_computation);
+ llvm::Function* false_function =
+ FindOrDie(emitted_functions_, false_computation);
+
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
+ llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
@@ -2424,12 +2495,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
- EmitGlobalCall(*conditional->true_computation(),
- IrName(conditional, "_true"));
+ EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
+ conditional_result, IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
- EmitGlobalCall(*conditional->false_computation(),
- IrName(conditional, "_false"));
+ EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
+ conditional_result, IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
@@ -2630,76 +2701,44 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
- const BufferAllocation& allocation = *slice.allocation();
- llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
- if (slice == computation_root_allocation_) {
- llvm::Argument* retval = compute_function_->result_arg();
- llvm::AttrBuilder attr_builder;
- attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
- attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
- retval->addAttrs(attr_builder);
- return retval;
- }
-
- auto param_it =
- computation_parameter_allocations_.find(slice.allocation()->index());
- if (param_it != computation_parameter_allocations_.end()) {
- int64 param_number = param_it->second;
- // We have to access the parameter at offset param_number in the params
- // array. The code generated here is equivalent to this C code:
- //
- // i8* param_address_untyped = params[param_number];
- // Param* param_address_typed = (Param*)param_address_untyped;
- //
- // Where Param is the actual element type of the underlying buffer (for
- // example, float for an XLA F32 element type).
- llvm::Value* params = compute_function_->parameters_arg();
- llvm::Value* param_address_offset =
- llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
-
- if (!ShapeUtil::IsOpaque(target_shape)) {
- AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
- AttachDereferenceableMetadataForLoad(param_address_untyped,
- target_shape);
- }
- return param_address_untyped;
- }
-
+ llvm::Type* element_type = IrShapeType(target_shape);
+ // The alignment and number of bytes within the temporary buffer is determined
+ // by the maximal shape as determined by buffer assignment.
+ const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
+ if (allocation.is_thread_local()) {
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
- std::pair<llvm::Function*, BufferAllocation::Slice> key = {
- compute_function_->function(), slice};
- auto buf_it = thread_local_buffers_.find(key);
- if (buf_it == thread_local_buffers_.end()) {
- llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm::AllocaInst*& tempbuf_address =
+ thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}];
+ if (tempbuf_address == nullptr) {
+ tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
MinimumAlignmentForShape(target_shape));
- auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
- CHECK(it_inserted_pair.second);
- buf_it = it_inserted_pair.first;
}
- return buf_it->second;
- }();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
-}
+ return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
+ }
+
+ if (allocation.is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, allocation.index());
+ }
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
- const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
- if (hlo_module_config_.debug_options()
+ if (is_top_level_computation_ &&
+ hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
+ // In the entry computation the parameter slots in the %params argument are
+ // invariant through program execution. In computations that are called
+ // from the entry computation (via kWhile, kCall and kConditional) the
+ // parameter slots are *not* invariant since they're written to by their
+ // callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@@ -2714,25 +2753,85 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ element_type->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
- if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
- } else if (slice.allocation()->is_constant()) {
- return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
- } else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
- }
+// Emits a function call returning a single array element. Allocates space
+// for a single element_type value, and loads it after call.
+llvm::Value* IrEmitter::EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* return_value_buffer = EmitArrayFunctionCall(
+ function, return_shape, 1, parameter_addresses, name);
+ return b_.CreateLoad(
+ return_value_buffer,
+ AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
+}
+
+// Emits a core function call based on the following pseudo-code.
+//
+// char** parameter_addresses_buffer =
+// allocate buffer with a pointer for each parameter to the function
+// for each parameter index, i.e. for i = 0, ..., #parameters:
+// parameter_addresses_buffer[i] = parameter_addresses[i]
+// call function(return_value_buffer,
+// parameter_addresses_buffer,
+// temps)
+// return return_value_buffer -- address of the return value.
+void IrEmitter::EmitArrayFunctionCallInto(
+ llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
+ b_.CreateCall(function,
+ GetArrayFunctionCallArguments(
+ parameter_addresses, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+}
+
+llvm::Value* IrEmitter::EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* elements =
+ llvm::ConstantInt::get(b_.getInt64Ty(), element_count);
+ PrimitiveType return_type = return_shape.element_type();
+ llvm::Value* return_value_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
+ tensorflow::strings::StrCat(name, "_return_value_address"), &b_,
+ MinimumAlignmentForPrimitiveType(return_type));
+ EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
+ name);
+ return return_value_buffer;
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
+ llvm::Value* addr;
const Shape& target_shape = op->shape();
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ if (op == op->parent()->root_instruction()) {
+ // For the root node, we write directly to the output buffer of the
+ // function.
+ llvm::Argument* retval = compute_function_->result_arg();
+ if ((ShapeUtil::IsArray(target_shape) &&
+ !ShapeUtil::IsZeroElementArray(target_shape)) ||
+ (ShapeUtil::IsTuple(target_shape) &&
+ !ShapeUtil::IsEmptyTuple(target_shape))) {
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttrs(attr_builder);
+ }
+ addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo());
+ } else {
+ // For other nodes, we need the temporary buffer allocated for this node to
+ // write the result into.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ addr = EmitTempBufferPointer(slice, target_shape);
+ }
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2837,69 +2936,20 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
-llvm::Value* IrEmitter::EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name) {
- const Shape& return_shape = callee.root_instruction()->shape();
-
- // Lifting this restriction to allow "small" arrays should be easy. Allowing
- // larger arrays is difficult because we allocate the buffer for this return
- // value on the stack.
- CHECK(ShapeUtil::IsScalar(return_shape));
-
- PrimitiveType return_type = return_shape.element_type();
-
- std::vector<llvm::Value*> parameter_addrs;
- for (llvm::Value* parameter : parameters) {
- CHECK(!parameter->getType()->isPointerTy());
- llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
- parameter_addrs.push_back(parameter_addr);
+StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
+ llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> argument_addrs;
+ for (auto argument : arguments) {
+ llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ argument->getType(), "arg_addr", &b_);
+ b_.CreateStore(argument, argument_addr);
+ argument_addrs.push_back(argument_addr);
}
-
- llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(return_type, module_),
- tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
- MinimumAlignmentForPrimitiveType(return_type));
-
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-
- return b_.CreateLoad(return_value_buffer);
+ return EmitElementFunctionCall(llvm_function,
+ ShapeUtil::MakeShape(return_type, {}),
+ argument_addrs, name);
}
-
-void IrEmitter::EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-}
-
-llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
- const HloComputation& callee) {
- const HloInstruction* root_inst = callee.root_instruction();
- if (root_inst->opcode() == HloOpcode::kOutfeed) {
- return llvm::Constant::getNullValue(b_.getInt8PtrTy());
- }
-
- const BufferAllocation::Slice root_buffer =
- assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
-}
-
} // namespace cpu
} // namespace xla