diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 410 |
1 files changed, 230 insertions, 180 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 60f9cd1121..a6d8551841 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -116,19 +116,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } - if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { - TF_ASSIGN_OR_RETURN( - computation_root_allocation_, - assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); - } - - for (const HloInstruction* param : computation->parameter_instructions()) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, - assignment_.GetUniqueTopLevelSlice(param)); - computation_parameter_allocations_[param_slice.allocation()->index()] = - param->parameter_number(); - } - InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -145,8 +132,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); - computation_root_allocation_ = BufferAllocation::Slice(); - computation_parameter_allocations_.clear(); return ir_function; } @@ -499,11 +484,23 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name) { - return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); +StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap( + HloMapInstruction* map, const llvm_ir::IrArray::Index& index) { + llvm::Function* mapped_ir_function = + FindOrDie(emitted_functions_, map->to_apply()); + std::vector<llvm::Value*> parameter_addresses; + for (const HloInstruction* operand : map->operands()) { + const llvm_ir::IrArray& array = GetIrArrayFor(operand); + parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_)); + } + return EmitElementFunctionCall(mapped_ir_function, map->shape(), + parameter_addresses, "map_function"); +} + +Status IrEmitter::HandleMap(HloInstruction* map) { + return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index); + }); } StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( @@ -511,6 +508,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // We fold inputs into the accumulator and initialize it to // the initial value on the reduce_window. @@ -563,10 +563,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( // We are not in the padding, so carry out the computation. llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); - llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); + llvm::Value* input_value_address = + input_array.EmitArrayElementAddress(input_index, &b_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce_window->shape(), + {accumulator_address, input_value_address}, "reducer_function"); b_.CreateStore(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -622,6 +623,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } + // The select and scatter computations should have been emitted previously. + llvm::Function* select_function = + FindOrDie(emitted_functions_, select_and_scatter->select()); + llvm::Function* scatter_function = + FindOrDie(emitted_functions_, select_and_scatter->scatter()); + // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -726,12 +733,11 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); + const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); - llvm::Value* result = EmitThreadLocalCall( - *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, + llvm::Value* result = EmitElementFunctionCall( + select_function, output_shape, {selected_value_address, operand_address}, "select_function"); // If the 'select' function returns false, update the selected value and the @@ -758,14 +764,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value = - source_array.EmitReadArrayElement(source_index, &b_); + llvm::Value* source_value_address = + source_array.EmitArrayElementAddress(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value = - output_array.EmitReadArrayElement(selected_index, &b_); - llvm::Value* scatter_value = - EmitThreadLocalCall(*select_and_scatter->scatter(), - {output_value, source_value}, "scatter_function"); + llvm::Value* output_value_address = + output_array.EmitArrayElementAddress(selected_index, &b_); + llvm::Value* scatter_value = EmitElementFunctionCall( + scatter_function, source->shape(), + {output_value_address, source_value_address}, "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1242,7 +1248,46 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - return EmitTargetAddressForOp(parameter); + auto param_number = parameter->parameter_number(); + auto param_shape = parameter->shape(); + + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset); + param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); + if (is_top_level_computation_ && + hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. + param_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); + } + + llvm::Value* param_address_typed = b_.CreateBitCast( + param_address_untyped, IrShapeType(param_shape)->getPointerTo()); + emitted_value_[parameter] = param_address_typed; + + if (!ShapeUtil::IsOpaque(param_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); + } + + VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); + return Status::OK(); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1706,6 +1751,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); gtl::ArraySlice<int64> dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1745,9 +1793,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( CHECK(index.end() == it); // Apply the reduction function to the loaded value. - llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); - llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + llvm::Value* input_address = + arg_array.EmitArrayElementAddress(input_index, &b_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce->shape(), {accumulator_addr, input_address}, "reduce_function"); b_.CreateStore(result, accumulator_addr); @@ -2085,13 +2134,18 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> parameter_addresses; + for (const HloInstruction* operand : call->operands()) { + parameter_addresses.push_back(GetEmittedValueFor(operand)); + } + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( - {}, &b_, computation->name(), + parameter_addresses, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2102,7 +2156,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) { call_args, root->shape(), root->outer_dimension_partitions(), &b_, call_ir_function, computation->name())); } else { - EmitGlobalCall(*computation, computation->name()); + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); } return Status::OK(); @@ -2183,6 +2238,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); + // The called computation should have been emitted previously. + llvm::Function* condition_ir_function = + FindOrDie(emitted_functions_, condition); + llvm::Function* body_ir_function = + FindOrDie(emitted_functions_, xla_while->while_body()); + // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2199,10 +2260,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); + llvm::Value* while_result = GetEmittedValueFor(xla_while); + llvm::Value* while_condition = EmitElementFunctionCall( + condition_ir_function, condition->root_instruction()->shape(), + {while_result}, IrName(xla_while, "cond")); llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + while_condition, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2217,8 +2280,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { b_.SetInsertPoint(body_bb); // Calls the body function. - EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); - + EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, + IrName(xla_while, "body")); // Finishes with a branch back to the header. b_.CreateBr(header_bb); @@ -2386,6 +2449,8 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2407,7 +2472,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) @@ -2424,12 +2495,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); SetToFirstInsertPoint(if_data.true_block, &b_); - EmitGlobalCall(*conditional->true_computation(), - IrName(conditional, "_true")); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); SetToFirstInsertPoint(if_data.false_block, &b_); - EmitGlobalCall(*conditional->false_computation(), - IrName(conditional, "_false")); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); @@ -2630,76 +2701,44 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - const BufferAllocation& allocation = *slice.allocation(); - llvm::Value* tempbuf_address = [&]() -> llvm::Value* { - if (slice == computation_root_allocation_) { - llvm::Argument* retval = compute_function_->result_arg(); - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - return retval; - } - - auto param_it = - computation_parameter_allocations_.find(slice.allocation()->index()); - if (param_it != computation_parameter_allocations_.end()) { - int64 param_number = param_it->second; - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); - - if (!ShapeUtil::IsOpaque(target_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, - target_shape); - } - return param_address_untyped; - } - + llvm::Type* element_type = IrShapeType(target_shape); + // The alignment and number of bytes within the temporary buffer is determined + // by the maximal shape as determined by buffer assignment. + const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); + if (allocation.is_thread_local()) { // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - std::pair<llvm::Function*, BufferAllocation::Slice> key = { - compute_function_->function(), slice}; - auto buf_it = thread_local_buffers_.find(key); - if (buf_it == thread_local_buffers_.end()) { - llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm::AllocaInst*& tempbuf_address = + thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}]; + if (tempbuf_address == nullptr) { + tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, MinimumAlignmentForShape(target_shape)); - auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); - CHECK(it_inserted_pair.second); - buf_it = it_inserted_pair.first; } - return buf_it->second; - }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); -} + return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); + } + + if (allocation.is_constant()) { + return FindOrDie(constant_buffer_to_global_, allocation.index()); + } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { - const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2714,25 +2753,85 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + element_type->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { - if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); - } else if (slice.allocation()->is_constant()) { - return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); - } else { - return EmitGlobalTempBufferPointer(slice, target_shape); - } +// Emits a function call returning a single array element. Allocates space +// for a single element_type value, and loads it after call. +llvm::Value* IrEmitter::EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* return_value_buffer = EmitArrayFunctionCall( + function, return_shape, 1, parameter_addresses, name); + return b_.CreateLoad( + return_value_buffer, + AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + b_.CreateCall(function, + GetArrayFunctionCallArguments( + parameter_addresses, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* elements = + llvm::ConstantInt::get(b_.getInt64Ty(), element_count); + PrimitiveType return_type = return_shape.element_type(); + llvm::Value* return_value_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, + tensorflow::strings::StrCat(name, "_return_value_address"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, + name); + return return_value_buffer; } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { + llvm::Value* addr; const Shape& target_shape = op->shape(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + if (op == op->parent()->root_instruction()) { + // For the root node, we write directly to the output buffer of the + // function. + llvm::Argument* retval = compute_function_->result_arg(); + if ((ShapeUtil::IsArray(target_shape) && + !ShapeUtil::IsZeroElementArray(target_shape)) || + (ShapeUtil::IsTuple(target_shape) && + !ShapeUtil::IsEmptyTuple(target_shape))) { + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + } + addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); + } else { + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + addr = EmitTempBufferPointer(slice, target_shape); + } addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2837,69 +2936,20 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name) { - const Shape& return_shape = callee.root_instruction()->shape(); - - // Lifting this restriction to allow "small" arrays should be easy. Allowing - // larger arrays is difficult because we allocate the buffer for this return - // value on the stack. - CHECK(ShapeUtil::IsScalar(return_shape)); - - PrimitiveType return_type = return_shape.element_type(); - - std::vector<llvm::Value*> parameter_addrs; - for (llvm::Value* parameter : parameters) { - CHECK(!parameter->getType()->isPointerTy()); - llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( - parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); - parameter_addrs.push_back(parameter_addr); +StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { + llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> argument_addrs; + for (auto argument : arguments) { + llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( + argument->getType(), "arg_addr", &b_); + b_.CreateStore(argument, argument_addr); + argument_addrs.push_back(argument_addr); } - - llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); - - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); - - return b_.CreateLoad(return_value_buffer); + return EmitElementFunctionCall(llvm_function, + ShapeUtil::MakeShape(return_type, {}), + argument_addrs, name); } - -void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); -} - -llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( - const HloComputation& callee) { - const HloInstruction* root_inst = callee.root_instruction(); - if (root_inst->opcode() == HloOpcode::kOutfeed) { - return llvm::Constant::getNullValue(b_.getInt8PtrTy()); - } - - const BufferAllocation::Slice root_buffer = - assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); -} - } // namespace cpu } // namespace xla |