From 4009f82f71f0421e4ed1f50d38e9105074062d1e Mon Sep 17 00:00:00 2001
From: Sanjoy Das
Date: Thu, 26 Jul 2018 16:43:58 -0700
Subject: Implement constant buffer allocation for XLA:GPU

This CL teaches XLA:GPU to use "normal" buffer assignment for constant
instructions. Constant instructions are mapped to a BufferAllocation, like
all other instructions, except the storage for this buffer is allocated
statically as a global in the generated PTX.

This CL does not change how we access the constants -- in
IrEmitterUnnested::BuildKernelThunk (used for top level computations) and in
HloToIrBindings::EmitBasePointersForHlos (used for nested computations) we
bind the kConstant instructions to the llvm::GlobalVariable backing them.
So users of constant instructions still access the globals corresponding to
the constants directly.

However, we no longer emit the constant literals inline. Instead we emit a
constant with a zero initializer and then memcpy in the contents of the
literal when we load the CUBIN/PTX. This works around compile time issues in
LLVM and ptxas caused by large constants.

We also populate `BufferAllocations` with the device pointers for the
constant globals. This is at least needed for TupleThunk today because
TupleThunk wants the addresses for the sub-buffers on the host. I'm not sure
if there are other places in XLA:GPU that rely on there being an entry in
BufferAllocations for every BufferAllocation.

PiperOrigin-RevId: 206243319
---
 tensorflow/compiler/xla/service/gpu/BUILD          |   2 +
 .../compiler/xla/service/gpu/buffer_allocations.cc |  62 ++++++++++-
 .../compiler/xla/service/gpu/buffer_allocations.h  |  15 +++
 .../compiler/xla/service/gpu/gpu_constants.cc      |   2 +
 .../compiler/xla/service/gpu/gpu_constants.h       |   3 +
 .../compiler/xla/service/gpu/gpu_copy_insertion.cc |  79 ++------------
 .../compiler/xla/service/gpu/gpu_executable.cc     |  55 +++++++++-
 .../compiler/xla/service/gpu/gpu_executable.h      |  21 +++-
 .../compiler/xla/service/gpu/hlo_to_ir_bindings.cc |  23 +++-
 tensorflow/compiler/xla/service/gpu/ir_emitter.cc  |  13 ---
 .../xla/service/gpu/ir_emitter_unnested.cc         | 116 +++++++++++++++++----
 .../compiler/xla/service/gpu/ir_emitter_unnested.h |   3 +
 .../compiler/xla/service/gpu/nvptx_compiler.cc     |  15 ++-
 tensorflow/compiler/xla/service/gpu/while_thunk.cc |   3 +
 14 files changed, 291 insertions(+), 121 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 4c21811698..625d1448e7 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -114,6 +114,7 @@ cc_library(
     srcs = ["hlo_to_ir_bindings.cc"],
     hdrs = ["hlo_to_ir_bindings.h"],
     deps = [
+        ":buffer_allocations",
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:buffer_assignment",
@@ -142,6 +143,7 @@ cc_library(
     ],
     deps = [
         ":backend_configs",
+        ":buffer_allocations",
         ":cudnn_convolution_runner",
         ":elemental_ir_emitter",
         ":gpu_constants",
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index b095d4cd73..20d4285766 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -44,12 +44,22 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
       num_buffers, device_ordinal, memory_allocator, buffer_assignment));
   for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
+    const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
+
const int64 expected_alignment = [&] { + if (allocation.is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (allocation.is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % - kEntryParameterAlignBytes != + if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %lld must be a multiple of %llx, but " @@ -62,7 +72,6 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; @@ -70,8 +79,7 @@ StatusOr> BufferAllocations::Builder::Build( OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % - kXlaAllocatedBufferAlignBytes != + if (reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " @@ -165,5 +173,49 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, buffers_[buffer_index] = buffer; } +static const HloInstruction& InstrForConstantBufferAllocation( + const BufferAllocation& allocation) { + CHECK(allocation.is_constant()); + HloInstruction* const_instr = nullptr; + for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = buffer_offset_pair.first; + // BufferAssignment may have assigned non-constant instructions to this + // allocation too so we can't CHECK this condition. E.g. for + // + // while(init = constant, body = identity, cond = ...) + // + // the LogicalBuffer for the kWhile instruction will have the same + // BufferAllocation as the LogicalBuffer for the (init) constant. + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + CHECK_EQ(const_instr, nullptr) + << const_instr->ToString() << " " << buffer->ToString(); + const_instr = buffer->instruction(); + } + } + CHECK_NE(const_instr, nullptr); + return *const_instr; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + string instr_name = InstrForConstantBufferAllocation(allocation).name(); + for (char& c : instr_name) { + if (c == '.') { + c = '_'; + } + } + return tensorflow::strings::StrCat("buffer_for_", instr_name); +} + +const Literal& LiteralForConstantAllocation( + const BufferAllocation& allocation) { + return InstrForConstantBufferAllocation(allocation).literal(); +} + +bool ShouldEmitLiteralInLlvmIr(const Literal& literal) { + // LLVM can sometimes do interesting optimizations using scalar constants. 
+ return ShapeUtil::IsScalar(literal.shape()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 6366235025..f21861ed81 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -107,6 +107,21 @@ class BufferAllocations { bool torn_down_ = false; }; +// In XLA:GPU we map constant buffer allocations to globals in the generated +// LLVM IR. This function gives us the name of the global variable a constant +// buffer is mapped to. +string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation); + +// Return the Literal corresponding to `allocation`, which must be a constant +// allocation. +const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation); + +// LLVM and PTXAS don't deal well with large constants, so we only emit very +// small constants directly in LLVM IR. Larger constants are emitted with zero +// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is +// loaded. +bool ShouldEmitLiteralInLlvmIr(const Literal& literal); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc index e6ddea6d25..7f0b030fec 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -30,5 +30,7 @@ const int64 kEntryParameterAlignBytes = 16; const int64 kXlaAllocatedBufferAlignBytes = tensorflow::Allocator::kAllocatorAlignment; +const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index 925e6927b6..6f5f1fa09c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -28,6 +28,9 @@ extern const int64 kEntryParameterAlignBytes; // out (result) buffers. extern const int64 kXlaAllocatedBufferAlignBytes; +// Minimum alignment for constant buffers. +extern const int64 kConstantBufferAlignBytes; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index fbc1303085..75f414e47f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -48,80 +48,17 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(*module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. Also, init values of while and conditional nodes cannot be - // constants. Insert copies for any constants found at the operands of these - // nodes. - tensorflow::gtl::FlatSet inserted_copies; + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. 
for (HloComputation* computation : module->computations()) { for (HloInstruction* hlo : computation->instructions()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - // Skip the operands that have already been replaced with a copy in a - // previous iteration (which is possible when a constant is used as an - // operand in multiple places). - if (ContainsKey(inserted_copies, operand)) { - return Status::OK(); - } - for (auto& pair : dataflow->GetInstructionValueSet(operand)) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->IsConstant() && - !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - inserted_copies.insert(copy); - changed = true; - } - } - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum || - hlo->opcode() == HloOpcode::kWhile || - hlo->opcode() == HloOpcode::kConditional) { - // For all other library calls, cross-replica-sum, while and conditional - // ops materialize all the operands into memory. (Cross-replica-sum - // gets its constant args materialized even if it's not implemented as a - // libcall to simplify the implementation. It's slower, but we can - // constant fold away constant args *anyway*, so we just need to make it - // work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; } - } - } - - if (changed) { - // Check the assumption that the epsilon and feature_index constants of the - // CUDNN batchnorm op are not shared with other ops where we would replace - // them with a copy. These custom op calls are generated with the - // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. 
- for (HloComputation* computation : module->computations()) { - for (HloInstruction* hlo : computation->instructions()) { - if (!IsCustomCallToDnnBatchNorm(*hlo)) { - continue; - } - for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); - ++i) { - CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); - } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 9767836cd6..52c8595ffb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -181,6 +181,51 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + module_spec.AddCudaCubinInMemory(cubin()); + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + ConstantBufferAllocationToGlobalName(allocation), module_handle)); + VLOG(3) << "Resolved global " + << ConstantBufferAllocationToGlobalName(allocation) << " to " + << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -192,6 +237,10 @@ StatusOr GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -213,8 +262,12 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 80ec38c3ac..c7ce6d0acb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ 
b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,7 +68,7 @@ class GpuExecutable : public Executable { } // Returns the compiled PTX for the computation. - tensorflow::StringPiece ptx() const { return ptx_; } + const string& ptx() const { return ptx_; } // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be // empty, in which case compilation is left up to the GPU driver. @@ -98,6 +100,15 @@ class GpuExecutable : public Executable { // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; + using BufferAllocToDeviceMemoryMap = + tensorflow::gtl::FlatMap; + + // Loads the PTX or CUBIN for this executable into `executor` and resolves the + // globals corresponding to constant buffers. Returns a map mapping buffer + // allocation indices to GPU pointers. + StatusOr ResolveConstantGlobals( + stream_executor::StreamExecutor* executor); + // The LLVM IR, in string format, of the unoptimized module generated for this // GpuExecutable. We save a string instead of an llvm::Module* because leaving // llvm::Module* in a singleton can cause the heap checker to emit false @@ -126,6 +137,14 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; + // Cache of module handles and constant buffer allocation maps used by + // `ResolveConstantGlobals`. + tensorflow::mutex module_handle_mutex_; + std::map + module_handles_ GUARDED_BY(module_handle_mutex_); + std::map + module_globals_ GUARDED_BY(module_handle_mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 1b6315ec03..c02a95d193 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -110,6 +111,11 @@ void HloToIrBindings::EmitBasePointersForHlos( llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + ConstantBufferAllocationToGlobalName(*slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); @@ -135,6 +141,14 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); } +// Returns true if `value` has a name that should not be changed. 
+static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; +} + llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, ShapeIndexView shape_index, llvm::Value* ir_value) { @@ -149,8 +163,13 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, } else { typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + } return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 973848c336..1295e83c0c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -81,19 +81,6 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } Status IrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl - << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) - << std::endl - << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5445d7b3ab..fb9540b7ef 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -231,11 +232,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ++arg_it; kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + + const int64 alignment = [&] { + if (alloc->is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (alloc->is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + kernel->addParamAttr( - arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - alloc->is_entry_computation_parameter() - ? 
kEntryParameterAlignBytes - : kXlaAllocatedBufferAlignBytes)); + arg_no, + llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -1763,6 +1773,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { .GetUniqueTopLevelSlice(tuple_element) .ok(); }); + // TODO(b/111689850): This logic isn't quite correct. + // // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as // a parameter, we would exceed the max allowable number of parameters to a @@ -1770,9 +1782,9 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { // buffer, we collect their buffer addresses in a host array, and then copy // that array to the tuple's buffer. // - // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back to - // emitting kernels which have access to their buffer addresses in code. + // Some tuple elements might not have an unambiguous buffer (like the result + // of a select-tuple). In that case, we fall back to emitting kernels which + // have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : tuple->operands()) { @@ -2299,11 +2311,6 @@ GetHloBufferSlices(const HloInstruction* hlo, // Adds entries for all subshapes of instr to `slices`. auto add_slices_for = [&](const HloInstruction* instr) { - // GPU constants don't have buffers; don't bother looking for one. - if (instr->IsConstant()) { - return; - } - ShapeUtil::ForEachSubshape( instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { if (slices.count({instr, index})) { @@ -2365,21 +2372,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. - std::vector buffers(buffers_needed.begin(), - buffers_needed.end()); - std::sort(buffers.begin(), buffers.end(), + std::vector non_constant_buffers; + c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); + + std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. 
std::unordered_map kernel_args; { auto arg_it = kernel->arg_begin(); - auto buffers_it = buffers.begin(); + auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; } @@ -2397,8 +2408,16 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + llvm::Value* loc; + if (slice.allocation()->is_constant()) { + loc = ir_emitter_context_->llvm_module()->getGlobalVariable( + llvm_ir::AsStringRef( + ConstantBufferAllocationToGlobalName(*slice.allocation()))); + CHECK_NE(loc, nullptr); + } else { + loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); + } // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. @@ -2421,9 +2440,9 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - implements_whole_instruction ? inst : nullptr, - unroll_factor); + return MakeUnique( + non_constant_buffers, llvm_ir::AsString(kernel->getName()), + implements_whole_instruction ? inst : nullptr, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2660,7 +2679,17 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the init_value was fused into this reduce we have to generate it first. if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + + const Literal& literal = init_value_operand->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *module_, initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + global_for_const->setAlignment(kConstantBufferAlignBytes); + bindings_.BindHloToIrValue(*init_value_operand, global_for_const); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( [=](const IrArray::Index& index) { @@ -3392,5 +3421,46 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return true; } +Status IrEmitterUnnested::EmitConstantGlobals() { + for (const BufferAllocation& allocation : + ir_emitter_context_->buffer_assignment().Allocations()) { + if (!allocation.is_constant()) { + continue; + } + + const Literal& literal = LiteralForConstantAllocation(allocation); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in the + // LLVM IR (e.g. 
TupleThunk) so we can't give then a linkage that merely + // preserves their names (like available_externally), we also need to ensure + // that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, + llvm_ir::AsStringRef(ConstantBufferAllocationToGlobalName(allocation))); + global_for_const->setAlignment(kConstantBufferAlignBytes); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + } + + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 616d8a2206..5254419907 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -92,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk); + // Emits LLVM global variables corresponding to constant instructions. + Status EmitConstantGlobals(); + private: // Builds the appropriate thunk for the instruction hlo and returns the owning // pointer to it. The caller needs to make sure `inst` outlives the lifetime diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 2eefadebcd..6d8996dac1 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -540,11 +540,13 @@ StatusOr> NVPTXCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/[](LogicalBuffer::Color) { - return kXlaAllocatedBufferAlignBytes; - })); + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); @@ -565,6 +567,9 @@ StatusOr> NVPTXCompiler::RunBackend( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 1315a4183a..d81d87e7dc 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -57,6 +57,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, while (true) { // Invoke thunk sequence for while 'condition' computation. 
     profiler->StartHloComputation();
+    VLOG(3) << "Executing condition computation";
     TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
         buffer_allocations, stream, profiler));
     profiler->FinishHloComputation(hlo_instruction()->while_condition());
@@ -64,6 +65,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
     // Copy the result of condition computation and break the loop if 'false'.
     bool condition_result;
     stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool));
+    VLOG(3) << "condition_result = " << condition_result;
     Status block_status = stream->BlockHostUntilDone();
     if (!block_status.ok()) {
       return InternalError(
@@ -78,6 +80,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
     // We measure the time of one execution of the while body computation. The
     // while body may be executed more than once, the last measurement "wins".
     profiler->StartHloComputation();
+    VLOG(3) << "Executing body computation";
     // Invoke thunk sequence for while 'body' computation, and pass on
     // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
     TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
-- 
cgit v1.2.3
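
The constant-handling contract this patch establishes between
IrEmitterUnnested::EmitConstantGlobals and GpuExecutable::ResolveConstantGlobals
can be restated compactly. The C++ sketch below is illustrative only:
ConstantGlobalName and ShouldEmitInitializerInIr are standalone stand-ins for
the patch's ConstantBufferAllocationToGlobalName and ShouldEmitLiteralInLlvmIr
helpers, and the "constant.42" instruction name is hypothetical.

// Illustrative sketch, not part of the patch.
#include <cstdint>
#include <iostream>
#include <string>

// Stand-in for ConstantBufferAllocationToGlobalName: the global backing a
// constant is named "buffer_for_" + the constant instruction's name, with
// '.' rewritten to '_'.
std::string ConstantGlobalName(std::string instr_name) {
  for (char& c : instr_name) {
    if (c == '.') {
      c = '_';
    }
  }
  return "buffer_for_" + instr_name;
}

// Stand-in for ShouldEmitLiteralInLlvmIr: only scalar literals are emitted
// inline in the LLVM IR; everything else gets a zero initializer and is
// memcpy'd into the loaded module's global.
bool ShouldEmitInitializerInIr(int64_t rank) { return rank == 0; }

int main() {
  // Prints "buffer_for_constant_42": the symbol GpuExecutable resolves with
  // GetUntypedSymbol after loading the CUBIN/PTX.
  std::cout << ConstantGlobalName("constant.42") << "\n";
  // A rank-2 array is not a scalar, so it is zero-initialized in the IR and
  // filled in with SynchronousMemcpyH2D at module-load time.
  std::cout << (ShouldEmitInitializerInIr(/*rank=*/2)
                    ? "emit literal inline in LLVM IR"
                    : "zero initializer in IR + H2D copy at load")
            << "\n";
  return 0;
}

The resulting device pointer for each constant global is then registered in
BufferAllocations under the constant's BufferAllocation::Index, which is what
lets thunks such as TupleThunk find the sub-buffer addresses on the host.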