diff options
Diffstat (limited to 'tensorflow/compiler')
17 files changed, 461 insertions, 390 deletions
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 5e74079fc1..7606420ded 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -64,7 +64,7 @@ size_t align_to(size_t n, size_t align) { size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { + if (sizes[i] > 0) { total += align_to(sizes[i], kAlign); } } @@ -85,7 +85,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] == -1) { + if (sizes[i] < 0) { + // bufs[i] is either a constant, an entry parameter or a thread local + // allocation. bufs[i] = nullptr; } else { bufs[i] = reinterpret_cast<void*>(pos); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 672e19bd93..ed5aa08c6f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -26,6 +26,8 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, result_index_(static_data.result_index), args_(new void*[static_data.num_args]), temps_(new void*[static_data.num_temps]), + arg_index_to_temp_index_(new int32[static_data.num_args]), + num_args_(static_data.num_args), arg_names_(static_data.arg_names), result_names_(static_data.result_names), program_shape_(static_data.program_shape), @@ -40,6 +42,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); + for (int i = 0; i < static_data.num_temps; i++) { + if (static_data.temp_sizes[i] < -1) { + int32 param_number = -(static_data.temp_sizes[i] + 2); + arg_index_to_temp_index_[param_number] = i; + } + } + // If Hlo profiling is enabled the 
generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function @@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } } +bool XlaCompiledCpuFunction::Run() { + // Propagate pointers to the argument buffers into the temps array. Code + // generated by XLA discovers the incoming argument pointers from the temps + // array. + for (int32 i = 0; i < num_args_; i++) { + temps_[arg_index_to_temp_index_[i]] = args_[i]; + } + raw_function_(temps_[result_index_], &run_options_, nullptr, temps_, + profile_counters_); + return true; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] arg_index_to_temp_index_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083ca..27cfb354bf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -60,9 +60,19 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function; - // Cardinality and sizes of arg and temp buffers. + // Cardinality and size of arg buffers. const intptr_t* arg_sizes = nullptr; size_t num_args = 0; + + // Cardinality and size of temp buffers. + // + // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. + // + // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The + // corresponding entry in the temp buffer array needs to be set to null. + // + // If temp_sizes[i] < -1 then the i'th temp is the entry parameter + // -(temp_sizes[i] + 2). 
const intptr_t* temp_sizes = nullptr; size_t num_temps = 0; @@ -113,11 +123,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast<const void**>(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -224,6 +230,17 @@ class XlaCompiledCpuFunction { void** args_ = nullptr; void** temps_ = nullptr; + // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for + // XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using temps_ as the sole storage for the arguments. + int32* arg_index_to_temp_index_; + + // The number of incoming arguments. + int32 num_args_; + // Backing memory for individual arg and temp buffers. void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 00ccfb1c78..114a9241bd 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes( std::vector<intptr_t> temp_sizes; temp_sizes.reserve(allocations.size()); for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { + if (allocation.is_constant() || allocation.is_thread_local()) { + // Constants are lowered to globals. 
Thread locals are lowered to + // allocas. temp_sizes.push_back(-1); + } else if (allocation.is_entry_computation_parameter()) { + // Entry computation parameters need some preprocessing in + // XlaCompiledCpuFunction::Run. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + temp_sizes.push_back(-allocation.parameter_number() - 2); } else { temp_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 6a7eb85e3b..128eea4828 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -156,9 +156,26 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - // Construct ObjectFile from machine code buffer. - return std::unique_ptr<llvm::MemoryBuffer>( + std::unique_ptr<llvm::MemoryBuffer> memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); + + if (VLOG_IS_ON(2)) { + llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file = + llvm::object::ObjectFile::createObjectFile(*memory_buffer); + if (obj_file) { + StatusOr<DisassemblerResult> disasm_result = + disassembler_->DisassembleObjectFile(*obj_file.get()); + if (disasm_result.ok()) { + XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); + } else { + LOG(WARNING) << "Could not disassemble object file!"; + } + } else { + LOG(WARNING) << "Could convert memory buffer to object file!"; + } + } + + return memory_buffer; } static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index b49ea89896..8cbe9a1b0d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -840,18 
+840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter() || - allocation.is_constant()) { - buffer_sizes.push_back(-1); - continue; - } // Callers don't need to allocate anything for thread-local temporary // buffers. They are lowered to allocas. if (allocation.is_thread_local()) { buffer_sizes.push_back(-1); continue; } + + // Callers don't need to allocate anything for constant buffers. They are + // lowered to globals. + if (allocation.is_constant()) { + buffer_sizes.push_back(-1); + continue; + } + + // Callers don't need to allocate anything for entry computation buffers, + // but they do need to stash the pointer to the entry computation buffer + // in the temp buffer table. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + if (allocation.is_entry_computation_parameter()) { + buffer_sizes.push_back(-allocation.parameter_number() - 2); + continue; + } + buffer_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 81e17a5cd4..946f5124b8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. 
compute_function_ = reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress())); + VLOG(1) << "compute_function_ at address " + << reinterpret_cast<void*>(compute_function_); } -Status CpuExecutable::AllocateBuffers( +StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, + std::vector<OwningDeviceMemory>>> +CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector<OwningDeviceMemory>* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) { + std::vector<se::DeviceMemoryBase> unowning_buffers( + assignment_->Allocations().size()); + std::vector<OwningDeviceMemory> owning_buffers( + assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -84,6 +91,8 @@ Status CpuExecutable::AllocateBuffers( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { + unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( + allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } @@ -99,34 +108,34 @@ Status CpuExecutable::AllocateBuffers( } int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { + if (!owning_buffers[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); + TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; + << owning_buffers[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers 
were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return Status::OK(); + return {{std::move(unowning_buffers), std::move(owning_buffers)}}; } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -136,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: An array of pointers, each of which points to a parameter. - // The size of this array is determined by the function's arity - // (ProgramShape). - // temps_array: An array of pointers, each of which points to a temporary - // buffer the computation needs. The size of this array is - // determined by buffer analysis. + // args_array: null + // temps_array: An array of pointers, containing pointers to temporary buffers + // required by the executable and pointers to entry computation + // parameters. 
// - std::vector<const void*> args_array; - for (const ShapedBuffer* argument : arguments) { - args_array.push_back(argument->root_buffer().opaque()); - } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -169,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[%zu], void* temps[%zu], " + " func(void* result, void* params[null], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters_size); + buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << tensorflow::strings::Printf( - " params = [%s]", - tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -186,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters); + compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), + profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -254,21 +255,18 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); + std::vector<OwningDeviceMemory> owning_buffers; std::vector<se::DeviceMemoryBase> 
unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } - TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), - arguments, unowning_buffers, - hlo_execution_profile)); + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction( + &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - return CreateResultShapedBuffer(run_options, &buffers); + return CreateResultShapedBuffer(run_options, &owning_buffers); } StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( @@ -284,17 +282,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - + std::vector<OwningDeviceMemory> owning_buffers; std::vector<se::DeviceMemoryBase> unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &buffers)); + CreateResultShapedBuffer(run_options, &owning_buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -312,7 +308,6 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { 
CpuExecutable* executable; ServiceExecutableRunOptions run_options; - std::vector<const ShapedBuffer*> arguments; std::vector<se::DeviceMemoryBase> unowning_buffers; std::shared_ptr<std::vector<OwningDeviceMemory>> buffers; @@ -320,15 +315,14 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), arguments, unowning_buffers, + &run_options.run_options(), unowning_buffers, /*hlo_execution_profile=*/nullptr)); } }; - host_stream->EnqueueTask(AsyncRunTask{ - this, *run_options, - std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()), - unowning_buffers, - std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))}); + host_stream->EnqueueTask( + AsyncRunTask{this, *run_options, std::move(unowning_buffers), + std::make_shared<std::vector<OwningDeviceMemory>>( + std::move(owning_buffers))}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8dd47bfb86..8af8a5dfec 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,20 +85,29 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. 
- Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, - int device_ordinal, - std::vector<OwningDeviceMemory>* buffers); + // Creates an array suitable for passing as the "temps" argument to the JIT + // compiled function pointer. + // + // Returns (unowning_buffers, owning_buffers) where: + // + // - unowning_buffers.data() can be passed as the temps argument as-is and + // includes pointers to the scratch storage required by the computation, + // the live-out buffer into which the result will be written and entry + // computation parameters. + // + // - owning_buffers contains owning pointers to the buffers that were + // allocated by this routine. This routine allocates buffers for temporary + // storage and the live-out buffer into which the computation writes its + // result. + StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, + std::vector<OwningDeviceMemory>>> + CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index cf955a8add..c13d36776f 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -117,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), - hlo->to_apply(), operands, - llvm_ir::IrName(hlo)); + return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo), + operands, llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a6d8551841..60f9cd1121 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -116,6 +116,19 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } + if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { + TF_ASSIGN_OR_RETURN( + computation_root_allocation_, + assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); + } + + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, + assignment_.GetUniqueTopLevelSlice(param)); + computation_parameter_allocations_[param_slice.allocation()->index()] = + param->parameter_number(); + } + InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. 
@@ -132,6 +145,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); + computation_root_allocation_ = BufferAllocation::Slice(); + computation_parameter_allocations_.clear(); return ir_function; } @@ -484,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap( - HloMapInstruction* map, const llvm_ir::IrArray::Index& index) { - llvm::Function* mapped_ir_function = - FindOrDie(emitted_functions_, map->to_apply()); - std::vector<llvm::Value*> parameter_addresses; - for (const HloInstruction* operand : map->operands()) { - const llvm_ir::IrArray& array = GetIrArrayFor(operand); - parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_)); - } - return EmitElementFunctionCall(mapped_ir_function, map->shape(), - parameter_addresses, "map_function"); -} - -Status IrEmitter::HandleMap(HloInstruction* map) { - return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) { - return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index); - }); +llvm::Value* IrEmitter::EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, + tensorflow::StringPiece name) { + return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( @@ -508,9 +511,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - // The called computation should have been emitted previously. 
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // We fold inputs into the accumulator and initialize it to // the initial value on the reduce_window. @@ -563,11 +563,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( // We are not in the padding, so carry out the computation. llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value_address = - input_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce_window->shape(), - {accumulator_address, input_value_address}, "reducer_function"); + llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce_window->to_apply(), + {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); b_.CreateStore(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -623,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } - // The select and scatter computations should have been emitted previously. - llvm::Function* select_function = - FindOrDie(emitted_functions_, select_and_scatter->select()); - llvm::Function* scatter_function = - FindOrDie(emitted_functions_, select_and_scatter->scatter()); - // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -733,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. 
SetToFirstInsertPoint(if_initialized.true_block, &b_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - select_function, output_shape, {selected_value_address, operand_address}, + llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* result = EmitThreadLocalCall( + *select_and_scatter->select(), + {b_.CreateLoad(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the @@ -764,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &b_); + llvm::Value* source_value = + source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value_address = - output_array.EmitArrayElementAddress(selected_index, &b_); - llvm::Value* scatter_value = EmitElementFunctionCall( - scatter_function, source->shape(), - {output_value_address, source_value_address}, "scatter_function"); + llvm::Value* output_value = + output_array.EmitReadArrayElement(selected_index, &b_); + llvm::Value* scatter_value = + EmitThreadLocalCall(*select_and_scatter->scatter(), + {output_value, source_value}, "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1248,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - auto param_number = parameter->parameter_number(); - auto 
param_shape = parameter->shape(); - - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset); - param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. 
- param_address_untyped->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); - } - - llvm::Value* param_address_typed = b_.CreateBitCast( - param_address_untyped, IrShapeType(param_shape)->getPointerTo()); - emitted_value_[parameter] = param_address_typed; - - if (!ShapeUtil::IsOpaque(param_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); - } - - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); - return Status::OK(); + return EmitTargetAddressForOp(parameter); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1751,9 +1706,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); gtl::ArraySlice<int64> dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1793,10 +1745,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( CHECK(index.end() == it); // Apply the reduction function to the loaded value. 
- llvm::Value* input_address = - arg_array.EmitArrayElementAddress(input_index, &b_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce->shape(), {accumulator_addr, input_address}, + llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, "reduce_function"); b_.CreateStore(result, accumulator_addr); @@ -2134,18 +2085,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); - std::vector<llvm::Value*> parameter_addresses; - for (const HloInstruction* operand : call->operands()) { - parameter_addresses.push_back(GetEmittedValueFor(operand)); - } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( - parameter_addresses, &b_, computation->name(), + {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2156,8 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { call_args, root->shape(), root->outer_dimension_partitions(), &b_, call_ir_function, computation->name())); } else { - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + EmitGlobalCall(*computation, computation->name()); } return Status::OK(); @@ -2238,12 +2183,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); - // The called computation should have been emitted previously. 
- llvm::Function* condition_ir_function = - FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = - FindOrDie(emitted_functions_, xla_while->while_body()); - // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2260,12 +2199,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - llvm::Value* while_result = GetEmittedValueFor(xla_while); - llvm::Value* while_condition = EmitElementFunctionCall( - condition_ir_function, condition->root_instruction()->shape(), - {while_result}, IrName(xla_while, "cond")); + EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); llvm::Value* while_predicate = b_.CreateICmpNE( - while_condition, + b_.CreateLoad( + GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2280,8 +2217,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { b_.SetInsertPoint(body_bb); // Calls the body function. - EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, - IrName(xla_while, "body")); + EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); + // Finishes with a branch back to the header. 
b_.CreateBr(header_bb); @@ -2449,8 +2386,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2472,13 +2407,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); - llvm::Function* true_function = - FindOrDie(emitted_functions_, true_computation); - llvm::Function* false_function = - FindOrDie(emitted_functions_, false_computation); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) @@ -2495,12 +2424,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); SetToFirstInsertPoint(if_data.true_block, &b_); - EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, - conditional_result, IrName(conditional, "_true")); + EmitGlobalCall(*conditional->true_computation(), + IrName(conditional, "_true")); SetToFirstInsertPoint(if_data.false_block, &b_); - EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, - conditional_result, IrName(conditional, "_false")); + EmitGlobalCall(*conditional->false_computation(), + IrName(conditional, "_false")); SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); @@ -2701,44 +2630,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) 
{ - llvm::Type* element_type = IrShapeType(target_shape); - // The alignment and number of bytes within the temporary buffer is determined - // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); - if (allocation.is_thread_local()) { + const BufferAllocation& allocation = *slice.allocation(); + llvm::Value* tempbuf_address = [&]() -> llvm::Value* { + if (slice == computation_root_allocation_) { + llvm::Argument* retval = compute_function_->result_arg(); + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + return retval; + } + + auto param_it = + computation_parameter_allocations_.find(slice.allocation()->index()); + if (param_it != computation_parameter_allocations_.end()) { + int64 param_number = param_it->second; + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = + b_.CreateLoad(param_address_offset); + + if (!ShapeUtil::IsOpaque(target_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } + return param_address_untyped; + } + // Thread-local allocations should only be assigned a single buffer. 
const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - llvm::AllocaInst*& tempbuf_address = - thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}]; - if (tempbuf_address == nullptr) { - tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + std::pair<llvm::Function*, BufferAllocation::Slice> key = { + compute_function_->function(), slice}; + auto buf_it = thread_local_buffers_.find(key); + if (buf_it == thread_local_buffers_.end()) { + llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, MinimumAlignmentForShape(target_shape)); + auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); + CHECK(it_inserted_pair.second); + buf_it = it_inserted_pair.first; } - return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); - } - - if (allocation.is_constant()) { - return FindOrDie(constant_buffer_to_global_, allocation.index()); - } + return buf_it->second; + }(); + return b_.CreateBitCast(tempbuf_address, + IrShapeType(target_shape)->getPointerTo()); +} +llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() + if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. 
In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2753,85 +2714,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } return b_.CreateBitCast(tempbuf_address_untyped, - element_type->getPointerTo()); -} - -// Emits a function call returning a single array element. Allocates space -// for a single element_type value, and loads it after call. -llvm::Value* IrEmitter::EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - gtl::ArraySlice<llvm::Value*> parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* return_value_buffer = EmitArrayFunctionCall( - function, return_shape, 1, parameter_addresses, name); - return b_.CreateLoad( - return_value_buffer, - AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); + IrShapeType(target_shape)->getPointerTo()); } -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. 
-void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - b_.CreateCall(function, - GetArrayFunctionCallArguments( - parameter_addresses, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); -} - -llvm::Value* IrEmitter::EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - gtl::ArraySlice<llvm::Value*> parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* elements = - llvm::ConstantInt::get(b_.getInt64Ty(), element_count); - PrimitiveType return_type = return_shape.element_type(); - llvm::Value* return_value_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, - tensorflow::strings::StrCat(name, "_return_value_address"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); - EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, - name); - return return_value_buffer; +llvm::Value* IrEmitter::EmitTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + if (slice.allocation()->is_thread_local()) { + return EmitThreadLocalTempBufferPointer(slice, target_shape); + } else if (slice.allocation()->is_constant()) { + return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + } else { + return EmitGlobalTempBufferPointer(slice, target_shape); + } } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { - // For the root node, we write directly to the output buffer of the - // function. 
- llvm::Argument* retval = compute_function_->result_arg(); - if ((ShapeUtil::IsArray(target_shape) && - !ShapeUtil::IsZeroElementArray(target_shape)) || - (ShapeUtil::IsTuple(target_shape) && - !ShapeUtil::IsEmptyTuple(target_shape))) { - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - } - addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2936,20 +2837,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { - llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); - std::vector<llvm::Value*> argument_addrs; - for (auto argument : arguments) { - llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( - argument->getType(), "arg_addr", &b_); - b_.CreateStore(argument, argument_addr); - argument_addrs.push_back(argument_addr); +llvm::Value* IrEmitter::EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice<llvm::Value*> parameters, + tensorflow::StringPiece name) { + const Shape& return_shape = callee.root_instruction()->shape(); + + // Lifting 
this restriction to allow "small" arrays should be easy. Allowing + // larger arrays is difficult because we allocate the buffer for this return + // value on the stack. + CHECK(ShapeUtil::IsScalar(return_shape)); + + PrimitiveType return_type = return_shape.element_type(); + + std::vector<llvm::Value*> parameter_addrs; + for (llvm::Value* parameter : parameters) { + CHECK(!parameter->getType()->isPointerTy()); + llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( + parameter->getType(), "arg_addr", &b_); + b_.CreateStore(parameter, parameter_addr); + parameter_addrs.push_back(parameter_addr); } - return EmitElementFunctionCall(llvm_function, - ShapeUtil::MakeShape(return_type, {}), - argument_addrs, name); + + llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), + tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + + b_.CreateCall( + FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); + + return b_.CreateLoad(return_value_buffer); } + +void IrEmitter::EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name) { + b_.CreateCall(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( + const HloComputation& callee) { + const 
HloInstruction* root_inst = callee.root_instruction(); + if (root_inst->opcode() == HloOpcode::kOutfeed) { + return llvm::Constant::getNullValue(b_.getInt8PtrTy()); + } + + const BufferAllocation::Slice root_buffer = + assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); + return EmitTempBufferPointer(root_buffer, root_inst->shape()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 03bbb2afb5..372017441f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -100,14 +100,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } - // Emits a call to `computation` with scalar arguments `arguments`. - StatusOr<llvm::Value*> EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name); - // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); + // Emit code to map one element according to `map_instr`. + llvm::Value* EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, + tensorflow::StringPiece name); + protected: // // The following methods implement the DfsHloVisitor interface. @@ -143,7 +144,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -218,9 +218,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. 
llvm::Value* GetTempBuffersArgument(); - // Emits code that computes the address of the given temporary buffer to the - // function. target_shape is the shape of this temporary buffer. - // The returned Value's type is a pointer to element_type. + // Helper for EmitTempBufferPointer. + llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); + + // Helper for EmitTempBufferPointer. + llvm::Value* EmitThreadLocalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape); + + // Emits code that computes the address of the given buffer allocation slice. + // + // TODO(sanjoy): This should be renamed to reflect that it no longer provides + // access to just temporaries. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -232,44 +241,27 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Methods that emit a function call. - // Parameters: - // function - The LLVM function to call. - // return_shape - The return shape of the HLO computation that was used to - // make the function. Not the same as the return type of the function - // in LLVM, since we use output parameters for the return type. - // element_count - number of elements to return (array form only). - // parameter_addresses - pointers to be passed to the function as - // parameters. - // name - used for LLVM IR register names. - - // Emits a function call, returning a scalar, often an element of a larger - // array. Returns a Value for the scalar element returned by the function. - llvm::Value* EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + // Emits a call to a thread local function (e.g. to the computation nested + // within a reduce or a map). 
Thread local callees (by definition) only write + // to and read from thread local allocations. + // + // `parameters` holds the *scalar values* that need to be passed to the + // callee. The return value is the scalar returned by the callee. + llvm::Value* EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice<llvm::Value*> parameters, tensorflow::StringPiece name); - // Array function call emitter. Stores the function's result into a supplied - // buffer. - // Parameters: - // function - The LLVM function to call. - // parameter_addresses - pointers to be passed to the function as - // parameters. - // return_value - pointer to a buffer where the call result is stored. - - void EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Array function call emitter. Returns a Value for the function's return - // value buffer address. The return value buffer is alloca'ed by this - // function. - llvm::Value* EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - tensorflow::StringPiece name); + // Emits a call to a "global" function (e.g. to the computation nested within + // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to + // the parameters and return values for these computations so there is no need + // to explicitly pass parameters or return results. + void EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name); + + // Returns the buffer to which a global call to `callee` would have written + // its result. + llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. 
@@ -408,11 +400,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map<HloComputation*, llvm::Function*> emitted_functions_; + std::map<const HloComputation*, llvm::Function*> emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, - llvm::AllocaInst*> + std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -422,6 +413,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::unique_ptr<IrFunction> compute_function_; llvm::IRBuilder<> b_; + // The buffer allocation slice for the root of the computation being compiled. + // Only relevant for thread local computations. + BufferAllocation::Slice computation_root_allocation_; + + // Maps the buffer allocation slices for the parameters to the computation + // being compiled to their parameter numbers. Only relevant for thread local + // computations. + tensorflow::gtl::FlatMap<BufferAllocation::Index, int64> + computation_parameter_allocations_; + // Maps HLO instructions to their index into the profile counter array. const std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 6aff838462..2db4d000f5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. 
- // temps: address of an array with pointers to temporary buffers. + // For thread local functions: + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: is null + // + // For global functions: + // retval: is null + // params: is null + // temps: address of an array with pointers to temporary buffers and entry + // computation parameters. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -196,18 +203,25 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = - b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); - b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_addresses_buffer; + + if (parameter_addresses.empty()) { + parameter_addresses_buffer = + llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); + } else { + parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + 
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } } const auto to_int8_ptr = [=](llvm::Value* ptr) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index 7b1f62541d..a5f34908d7 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -66,6 +66,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -80,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, params, temps, + function(result_ptr, run_options_ptr, nullptr, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index 941d940684..fe5ec1cc66 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -56,12 +56,12 @@ 
ENTRY while3 { )"; CompileAndVerifyIr(hlo_string, R"( -; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval +; CHECK-LABEL: @body(i8* %retval ; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; -; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0 +; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 47cab79604..115448c908 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; - void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {nullptr, &out}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + void* temporary_buffers[] = {&opaque_data, &out}; + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 12.0f); 
} diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 74494e60e8..e310966d8b 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -93,7 +93,7 @@ int main(int argc, char** argv) { // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); CHECK_EQ(result->buffer_sizes().size(), 3); - CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer + CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer if (triple.isOSBinFormatELF()) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index c81c27891c..1bdf1867b9 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1236,6 +1236,35 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0<int32>(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0<int32>(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); + 
TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false))); + + ComputeAndCompareR0<int32>(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); |