diff options
author | 2018-07-31 17:21:26 -0700 | |
---|---|---|
committer | 2018-07-31 17:26:19 -0700 | |
commit | 182a00ee781017932443bacb475af7acc4a56d5a (patch) | |
tree | ec24c5cc7e424642bf6772b4d46f7a08ac31dc88 | |
parent | 64f191cdc0121bbcb322c3b11b160d638c2f4af9 (diff) |
Automated rollback of commit fba2d773f45f10882aa475ac75cbf9884995d626
PiperOrigin-RevId: 206855848
17 files changed, 389 insertions, 460 deletions
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 475eebaa35..5e74079fc1 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -85,9 +85,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] < 0) { - // bufs[i] is either a constant, an entry parameter or a thread local - // allocation. + if (sizes[i] == -1) { bufs[i] = nullptr; } else { bufs[i] = reinterpret_cast<void*>(pos); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index ed5aa08c6f..672e19bd93 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -26,8 +26,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, result_index_(static_data.result_index), args_(new void*[static_data.num_args]), temps_(new void*[static_data.num_temps]), - arg_index_to_temp_index_(new int32[static_data.num_args]), - num_args_(static_data.num_args), arg_names_(static_data.arg_names), result_names_(static_data.result_names), program_shape_(static_data.program_shape), @@ -42,13 +40,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - for (int i = 0; i < static_data.num_temps; i++) { - if (static_data.temp_sizes[i] < -1) { - int32 param_number = -(static_data.temp_sizes[i] + 2); - arg_index_to_temp_index_[param_number] = i; - } - } - // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function @@ -59,24 +50,11 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } } -bool XlaCompiledCpuFunction::Run() { - // Propagate pointers to the argument buffers into the temps array. Code - // generated by XLA discovers the incoming argument pointers from the temps - // array. - for (int32 i = 0; i < num_args_; i++) { - temps_[arg_index_to_temp_index_[i]] = args_[i]; - } - raw_function_(temps_[result_index_], &run_options_, nullptr, temps_, - profile_counters_); - return true; -} - XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; - delete[] arg_index_to_temp_index_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 27cfb354bf..48a8c083ca 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -60,19 +60,9 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function; - // Cardinality and size of arg buffers. + // Cardinality and sizes of arg and temp buffers. const intptr_t* arg_sizes = nullptr; size_t num_args = 0; - - // Cardinality and size of temp buffers. - // - // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. - // - // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The - // corresponding entry in the temp buffer array needs to be set to null. - // - // If temp_sizes[i] < -1 then the i'th temp is the entry parameter - // -(temp_sizes[i] + 2). const intptr_t* temp_sizes = nullptr; size_t num_temps = 0; @@ -123,7 +113,11 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run(); + bool Run() { + raw_function_(temps_[result_index_], &run_options_, + const_cast<const void**>(args_), temps_, profile_counters_); + return true; + } // Returns the error message from the previous failed Run call. // @@ -230,17 +224,6 @@ class XlaCompiledCpuFunction { void** args_ = nullptr; void** temps_ = nullptr; - // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for - // XLA generated code to be able to find it. - // - // For now we need to keep around the args_ array because there is code that - // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using temps_ as the sole storage for the arguments. - int32* arg_index_to_temp_index_; - - // The number of incoming arguments. - int32 num_args_; - // Backing memory for individual arg and temp buffers. void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 114a9241bd..00ccfb1c78 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -58,15 +58,11 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes( std::vector<intptr_t> temp_sizes; temp_sizes.reserve(allocations.size()); for (const xla::BufferAllocation& allocation : allocations) { - if (allocation.is_constant() || allocation.is_thread_local()) { - // Constants are lowered to globals. Thread locals are lowered to - // allocas. + // Callers don't allocate temporary buffers for parameters. Nor for + // thread-local buffers, which are lowered to alloca. + if (allocation.is_entry_computation_parameter() || + allocation.is_thread_local()) { temp_sizes.push_back(-1); - } else if (allocation.is_entry_computation_parameter()) { - // Entry computation parameters need some preprocessing in - // XlaCompiledCpuFunction::Run. See the comment on - // XlaCompiledCpuFunction::StaticData::temp_sizes. - temp_sizes.push_back(-allocation.parameter_number() - 2); } else { temp_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 128eea4828..6a7eb85e3b 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -156,26 +156,9 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - std::unique_ptr<llvm::MemoryBuffer> memory_buffer( + // Construct ObjectFile from machine code buffer. + return std::unique_ptr<llvm::MemoryBuffer>( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); - - if (VLOG_IS_ON(2)) { - llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file = - llvm::object::ObjectFile::createObjectFile(*memory_buffer); - if (obj_file) { - StatusOr<DisassemblerResult> disasm_result = - disassembler_->DisassembleObjectFile(*obj_file.get()); - if (disasm_result.ok()) { - XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); - } else { - LOG(WARNING) << "Could not disassemble object file!"; - } - } else { - LOG(WARNING) << "Could convert memory buffer to object file!"; - } - } - - return memory_buffer; } static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 8cbe9a1b0d..b49ea89896 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -840,29 +840,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate anything for thread-local temporary - // buffers. They are lowered to allocas. - if (allocation.is_thread_local()) { + // Callers don't need to allocate temporary buffers for parameters. + if (allocation.is_entry_computation_parameter() || + allocation.is_constant()) { buffer_sizes.push_back(-1); continue; } - - // Callers don't need to allocate anything for constant buffers. They are - // lowered to globals. - if (allocation.is_constant()) { + // Callers don't need to allocate anything for thread-local temporary + // buffers. They are lowered to allocas. + if (allocation.is_thread_local()) { buffer_sizes.push_back(-1); continue; } - - // Callers don't need to allocate anything for entry computation buffers, - // but they do need to stash the pointer to the entry computation buffer - // in the temp buffer table. See the comment on - // XlaCompiledCpuFunction::StaticData::temp_sizes. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-allocation.parameter_number() - 2); - continue; - } - buffer_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 946f5124b8..81e17a5cd4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,19 +69,12 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. compute_function_ = reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress())); - VLOG(1) << "compute_function_ at address " - << reinterpret_cast<void*>(compute_function_); } -StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, - std::vector<OwningDeviceMemory>>> -CpuExecutable::CreateTempArray( +Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) { - std::vector<se::DeviceMemoryBase> unowning_buffers( - assignment_->Allocations().size()); - std::vector<OwningDeviceMemory> owning_buffers( - assignment_->Allocations().size()); + std::vector<OwningDeviceMemory>* buffers) { + CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -91,8 +84,6 @@ CpuExecutable::CreateTempArray( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { - unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( - allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } @@ -108,34 +99,34 @@ CpuExecutable::CreateTempArray( } int64 buffer_size = allocation.size(); - if (!owning_buffers[i].is_null()) { + if (!(*buffers)[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); - unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); + TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << owning_buffers[i].opaque() << "]"; + << (*buffers)[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return {{std::move(unowning_buffers), std::move(owning_buffers)}}; + return Status::OK(); } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -145,11 +136,17 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: null - // temps_array: An array of pointers, containing pointers to temporary buffers - // required by the executable adn pointers to entry computation - // parameters. + // args_array: An array of pointers, each of which points to a parameter. + // The size of this array is determined by the function's arity + // (ProgramShape). + // temps_array: An array of pointers, each of which points to a temporary + // buffer the computation needs. The size of this array is + // determined by buffer analysis. // + std::vector<const void*> args_array; + for (const ShapedBuffer* argument : arguments) { + args_array.push_back(argument->root_buffer().opaque()); + } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -172,14 +169,16 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " + " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - buffer_pointers.size(), profile_counters_size); + args_array.size(), buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << " params = nullptr"; + VLOG(3) << tensorflow::strings::Printf( + " params = [%s]", + tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -187,8 +186,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), - profile_counters); + compute_function_(result_buffer, run_options, args_array.data(), + buffer_pointers.data(), profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -255,18 +254,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size()); - std::vector<OwningDeviceMemory> owning_buffers; - std::vector<se::DeviceMemoryBase> unowning_buffers; - TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), unowning_buffers, hlo_execution_profile)); + std::vector<se::DeviceMemoryBase> unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - return CreateResultShapedBuffer(run_options, &owning_buffers); + return CreateResultShapedBuffer(run_options, &buffers); } StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( @@ -282,15 +284,17 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector<OwningDeviceMemory> owning_buffers; - std::vector<se::DeviceMemoryBase> unowning_buffers; - TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size()); + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + std::vector<se::DeviceMemoryBase> unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &owning_buffers)); + CreateResultShapedBuffer(run_options, &buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -308,6 +312,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; + std::vector<const ShapedBuffer*> arguments; std::vector<se::DeviceMemoryBase> unowning_buffers; std::shared_ptr<std::vector<OwningDeviceMemory>> buffers; @@ -315,14 +320,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), unowning_buffers, + &run_options.run_options(), arguments, unowning_buffers, /*hlo_execution_profile=*/nullptr)); } }; - host_stream->EnqueueTask( - AsyncRunTask{this, *run_options, std::move(unowning_buffers), - std::make_shared<std::vector<OwningDeviceMemory>>( - std::move(owning_buffers))}); + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8af8a5dfec..8dd47bfb86 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,29 +85,20 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Creates an array suitable for passing as the "temps" argument to the JIT - // compiled function pointer. - // - // Returns (unowning_buffers, owning_buffers) where: - // - // - unowning_buffers.data() can be passed as the temps argument as-is and - // includes pointers to the scratch storage required by the computation, - // the live-out buffer into which the result will be written and entry - // computation parameters. - // - // - owning_buffers contains owning pointers to the buffers that were - // allocated by this routine. This routine allocates buffers for temporary - // storage and the live-out buffer into which the computation writes it - // result. - StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, - std::vector<OwningDeviceMemory>>> - CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments); + // Allocate buffers required for execution and assign them to the elements of + // "buffers". "buffers" should be sized to the number of buffers in buffer + // assignment. Each vector element corresponds to a particular Index. If + // a vector element already contains a non-null DeviceMemoryBase, then no + // buffer is assigned for this element. + Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, + int device_ordinal, + std::vector<OwningDeviceMemory>* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index c13d36776f..cf955a8add 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,8 +19,6 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -119,8 +117,9 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo), - operands, llvm_ir::IrName(hlo)); + return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), + hlo->to_apply(), operands, + llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 60f9cd1121..a6d8551841 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -116,19 +116,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } - if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { - TF_ASSIGN_OR_RETURN( - computation_root_allocation_, - assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); - } - - for (const HloInstruction* param : computation->parameter_instructions()) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, - assignment_.GetUniqueTopLevelSlice(param)); - computation_parameter_allocations_[param_slice.allocation()->index()] = - param->parameter_number(); - } - InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -145,8 +132,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); - computation_root_allocation_ = BufferAllocation::Slice(); - computation_parameter_allocations_.clear(); return ir_function; } @@ -499,11 +484,23 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name) { - return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); +StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap( + HloMapInstruction* map, const llvm_ir::IrArray::Index& index) { + llvm::Function* mapped_ir_function = + FindOrDie(emitted_functions_, map->to_apply()); + std::vector<llvm::Value*> parameter_addresses; + for (const HloInstruction* operand : map->operands()) { + const llvm_ir::IrArray& array = GetIrArrayFor(operand); + parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_)); + } + return EmitElementFunctionCall(mapped_ir_function, map->shape(), + parameter_addresses, "map_function"); +} + +Status IrEmitter::HandleMap(HloInstruction* map) { + return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index); + }); } StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( @@ -511,6 +508,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( const llvm_ir::IrArray::Index& index) { const HloInstruction* operand = reduce_window->operand(0); const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // We fold inputs into the accumulator and initialize it to // the initial value on the reduce_window. @@ -563,10 +563,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( // We are not in the padding, so carry out the computation. llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); - llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); + llvm::Value* input_value_address = + input_array.EmitArrayElementAddress(input_index, &b_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce_window->shape(), + {accumulator_address, input_value_address}, "reducer_function"); b_.CreateStore(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -622,6 +623,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } + // The select and scatter computations should have been emitted previously. + llvm::Function* select_function = + FindOrDie(emitted_functions_, select_and_scatter->select()); + llvm::Function* scatter_function = + FindOrDie(emitted_functions_, select_and_scatter->scatter()); + // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -726,12 +733,11 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); + const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); - llvm::Value* result = EmitThreadLocalCall( - *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, + llvm::Value* result = EmitElementFunctionCall( + select_function, output_shape, {selected_value_address, operand_address}, "select_function"); // If the 'select' function returns false, update the selected value and the @@ -758,14 +764,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value = - source_array.EmitReadArrayElement(source_index, &b_); + llvm::Value* source_value_address = + source_array.EmitArrayElementAddress(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value = - output_array.EmitReadArrayElement(selected_index, &b_); - llvm::Value* scatter_value = - EmitThreadLocalCall(*select_and_scatter->scatter(), - {output_value, source_value}, "scatter_function"); + llvm::Value* output_value_address = + output_array.EmitArrayElementAddress(selected_index, &b_); + llvm::Value* scatter_value = EmitElementFunctionCall( + scatter_function, source->shape(), + {output_value_address, source_value_address}, "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1242,7 +1248,46 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - return EmitTargetAddressForOp(parameter); + auto param_number = parameter->parameter_number(); + auto param_shape = parameter->shape(); + + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset); + param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); + if (is_top_level_computation_ && + hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. + param_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); + } + + llvm::Value* param_address_typed = b_.CreateBitCast( + param_address_untyped, IrShapeType(param_shape)->getPointerTo()); + emitted_value_[parameter] = param_address_typed; + + if (!ShapeUtil::IsOpaque(param_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); + } + + VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); + return Status::OK(); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1706,6 +1751,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); gtl::ArraySlice<int64> dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + // The called computation should have been emitted previously. + llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1745,9 +1793,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( CHECK(index.end() == it); // Apply the reduction function to the loaded value. - llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); - llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + llvm::Value* input_address = + arg_array.EmitArrayElementAddress(input_index, &b_); + llvm::Value* result = EmitElementFunctionCall( + reducer_function, reduce->shape(), {accumulator_addr, input_address}, "reduce_function"); b_.CreateStore(result, accumulator_addr); @@ -2085,13 +2134,18 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> parameter_addresses; + for (const HloInstruction* operand : call->operands()) { + parameter_addresses.push_back(GetEmittedValueFor(operand)); + } + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( - {}, &b_, computation->name(), + parameter_addresses, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2102,7 +2156,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) { call_args, root->shape(), root->outer_dimension_partitions(), &b_, call_ir_function, computation->name())); } else { - EmitGlobalCall(*computation, computation->name()); + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); } return Status::OK(); @@ -2183,6 +2238,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); + // The called computation should have been emitted previously. + llvm::Function* condition_ir_function = + FindOrDie(emitted_functions_, condition); + llvm::Function* body_ir_function = + FindOrDie(emitted_functions_, xla_while->while_body()); + // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2199,10 +2260,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); + llvm::Value* while_result = GetEmittedValueFor(xla_while); + llvm::Value* while_condition = EmitElementFunctionCall( + condition_ir_function, condition->root_instruction()->shape(), + {while_result}, IrName(xla_while, "cond")); llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + while_condition, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2217,8 +2280,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { b_.SetInsertPoint(body_bb); // Calls the body function. - EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); - + EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, + IrName(xla_while, "body")); // Finishes with a branch back to the header. b_.CreateBr(header_bb); @@ -2386,6 +2449,8 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2407,7 +2472,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) @@ -2424,12 +2495,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); SetToFirstInsertPoint(if_data.true_block, &b_); - EmitGlobalCall(*conditional->true_computation(), - IrName(conditional, "_true")); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); SetToFirstInsertPoint(if_data.false_block, &b_); - EmitGlobalCall(*conditional->false_computation(), - IrName(conditional, "_false")); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); @@ -2630,76 +2701,44 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - const BufferAllocation& allocation = *slice.allocation(); - llvm::Value* tempbuf_address = [&]() -> llvm::Value* { - if (slice == computation_root_allocation_) { - llvm::Argument* retval = compute_function_->result_arg(); - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - return retval; - } - - auto param_it = - computation_parameter_allocations_.find(slice.allocation()->index()); - if (param_it != computation_parameter_allocations_.end()) { - int64 param_number = param_it->second; - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); - - if (!ShapeUtil::IsOpaque(target_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, - target_shape); - } - return param_address_untyped; - } - + llvm::Type* element_type = IrShapeType(target_shape); + // The alignment and number of bytes within the temporary buffer is determined + // by the maximal shape as determined by buffer assignment. + const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); + if (allocation.is_thread_local()) { // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - std::pair<llvm::Function*, BufferAllocation::Slice> key = { - compute_function_->function(), slice}; - auto buf_it = thread_local_buffers_.find(key); - if (buf_it == thread_local_buffers_.end()) { - llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm::AllocaInst*& tempbuf_address = + thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}]; + if (tempbuf_address == nullptr) { + tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, MinimumAlignmentForShape(target_shape)); - auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); - CHECK(it_inserted_pair.second); - buf_it = it_inserted_pair.first; } - return buf_it->second; - }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); -} + return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); + } + + if (allocation.is_constant()) { + return FindOrDie(constant_buffer_to_global_, allocation.index()); + } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { - const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2714,25 +2753,85 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + element_type->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { - if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); - } else if (slice.allocation()->is_constant()) { - return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); - } else { - return EmitGlobalTempBufferPointer(slice, target_shape); - } +// Emits a function call returning a single array element. Allocates space +// for a single element_type value, and loads it after call. +llvm::Value* IrEmitter::EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* return_value_buffer = EmitArrayFunctionCall( + function, return_shape, 1, parameter_addresses, name); + return b_.CreateLoad( + return_value_buffer, + AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + b_.CreateCall(function, + GetArrayFunctionCallArguments( + parameter_addresses, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name) { + llvm::Value* elements = + llvm::ConstantInt::get(b_.getInt64Ty(), element_count); + PrimitiveType return_type = return_shape.element_type(); + llvm::Value* return_value_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, + tensorflow::strings::StrCat(name, "_return_value_address"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, + name); + return return_value_buffer; } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { + llvm::Value* addr; const Shape& target_shape = op->shape(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + if (op == op->parent()->root_instruction()) { + // For the root node, we write directly to the output buffer of the + // function. + llvm::Argument* retval = compute_function_->result_arg(); + if ((ShapeUtil::IsArray(target_shape) && + !ShapeUtil::IsZeroElementArray(target_shape)) || + (ShapeUtil::IsTuple(target_shape) && + !ShapeUtil::IsEmptyTuple(target_shape))) { + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + } + addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); + } else { + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + addr = EmitTempBufferPointer(slice, target_shape); + } addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2837,69 +2936,20 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name) { - const Shape& return_shape = callee.root_instruction()->shape(); - - // Lifting this restriction to allow "small" arrays should be easy. Allowing - // larger arrays is difficult because we allocate the buffer for this return - // value on the stack. - CHECK(ShapeUtil::IsScalar(return_shape)); - - PrimitiveType return_type = return_shape.element_type(); - - std::vector<llvm::Value*> parameter_addrs; - for (llvm::Value* parameter : parameters) { - CHECK(!parameter->getType()->isPointerTy()); - llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( - parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); - parameter_addrs.push_back(parameter_addr); +StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { + llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> argument_addrs; + for (auto argument : arguments) { + llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( + argument->getType(), "arg_addr", &b_); + b_.CreateStore(argument, argument_addr); + argument_addrs.push_back(argument_addr); } - - llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); - - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); - - return b_.CreateLoad(return_value_buffer); + return EmitElementFunctionCall(llvm_function, + ShapeUtil::MakeShape(return_type, {}), + argument_addrs, name); } - -void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); -} - -llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( - const HloComputation& callee) { - const HloInstruction* root_inst = callee.root_instruction(); - if (root_inst->opcode() == HloOpcode::kOutfeed) { - return llvm::Constant::getNullValue(b_.getInt8PtrTy()); - } - - const BufferAllocation::Slice root_buffer = - assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 372017441f..03bbb2afb5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -100,15 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } + // Emits a call to `computation` with scalar arguments `arguments`. + StatusOr<llvm::Value*> EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. - llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name); - protected: // // The following methods implement the DfsHloVisitor interface. @@ -144,6 +143,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -218,18 +218,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); - - // Helper for EmitTempBufferPointer. - llvm::Value* EmitThreadLocalTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape); - - // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. + // Emits code that computes the address of the given temporary buffer to the + // function. target_shape is the shape of this temporary buffer. + // The returned Value's type is a pointer to element_type. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -241,27 +232,44 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. to the computation nested - // within a reduce or a map). Thread local callees (by definition) only write - // to and read from thread local allocations. - // - // `parameters` holds the *scalar values* that need to be passed to the - // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice<llvm::Value*> parameters, + // Methods that emit a function call. + // Parameters: + // function - The LLVM function to call. + // return_shape - The return shape of the HLO computation that was used to + // make the function. Not the same as the return type of the function + // in LLVM, since we use output parameters for the return type. + // element_count - number of elements to return (array form only). + // parameter_addresses - pointers to be passed to the function as + // parameters. + // name - used for LLVM IR register names. + + // Emits a function call, returning a scalar, often an element of a larger + // array. Returns a Value for the scalar element returned by the function. + llvm::Value* EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, tensorflow::StringPiece name); - // Emits a call to a "global" function (e.g. to the computation nested within - // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to - // the parameters and return values for these computations so there is no need - // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); - - // Returns the buffer to which a global call to `callee` would have written - // its result. - llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); + // Array function call emitter. Stores the function's result into a supplied + // buffer. + // Parameters: + // function - The LLVM function to call. + // parameter_addresses - pointers to be passed to the function as + // parameters. + // return_value - pointer to a buffer where the call result is stored. + + void EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Array function call emitter. Returns a Value for the function's return + // value buffer address. The return value buffer is alloca'ed by this + // function. + llvm::Value* EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -400,10 +408,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map<const HloComputation*, llvm::Function*> emitted_functions_; + std::map<HloComputation*, llvm::Function*> emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*> + std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, + llvm::AllocaInst*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -413,16 +422,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::unique_ptr<IrFunction> compute_function_; llvm::IRBuilder<> b_; - // The buffer allocation slice for the root of the computation being compiled. - // Only relevant for thread local computations. - BufferAllocation::Slice computation_root_allocation_; - - // Maps the buffer allocation slices for the parameters to the computation - // being compiled to their parameter numbers. Only relevant for thread local - // computations. - tensorflow::gtl::FlatMap<BufferAllocation::Index, int64> - computation_parameter_allocations_; - // Maps HLO instructions to their index into the profile counter array. const std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5..6aff838462 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -80,16 +80,9 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // For thread local functions: - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: is null - // - // For global functions: - // retval: is null - // params: is null - // temps: address of an array with pointers to temporary buffers and entry - // computation parameters. + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -203,25 +196,18 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer; - - if (parameter_addresses.empty()) { - parameter_addresses_buffer = - llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); - } else { - parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); - - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = - b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); - b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); - } + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); } const auto to_int8_ptr = [=](llvm::Value* ptr) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index 52b6c8eb80..d03da46575 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -65,7 +65,6 @@ void __xla_cpu_runtime_ParallelForkJoin( VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; - CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -80,9 +79,9 @@ void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, nullptr, temps, + function(result_ptr, run_options_ptr, params, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index fe5ec1cc66..941d940684 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -56,12 +56,12 @@ ENTRY while3 { )"; CompileAndVerifyIr(hlo_string, R"( -; CHECK-LABEL: @body(i8* %retval +; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval ; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; -; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 115448c908..47cab79604 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,12 +42,13 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; + void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {&opaque_data, &out}; - SumAndDouble(&out, &run_options, nullptr, temporary_buffers); + void* temporary_buffers[] = {nullptr, &out}; + SumAndDouble(&out, &run_options, parameters, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, nullptr, temporary_buffers); + SumAndDouble(&out, &run_options, parameters, temporary_buffers); EXPECT_EQ(out, 12.0f); } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index e310966d8b..74494e60e8 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -93,7 +93,7 @@ int main(int argc, char** argv) { // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); CHECK_EQ(result->buffer_sizes().size(), 3); - CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer + CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer if (triple.isOSBinFormatELF()) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9..c81c27891c 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1236,35 +1236,6 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { - auto while_shape = ShapeUtil::MakeShape(S32, {}); - - XlaComputation condition; - { - XlaBuilder builder("condition"); - Parameter(&builder, 0, while_shape, "state"); - Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); - TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); - } - - XlaComputation body; - { - XlaBuilder builder("body"); - auto indvar = Parameter(&builder, 0, while_shape, "state"); - Add(indvar, ConstantR0<int32>(&builder, 1)); - TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); - } - - XlaBuilder builder(TestName()); - While(condition, body, ConstantR0<int32>(&builder, 0)); - - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false))); - - ComputeAndCompareR0<int32>(&builder, 2, {}); -} - void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); |