diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/gpu_executable.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 55 |
1 files changed, 54 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 9767836cd6..52c8595ffb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -181,6 +181,51 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr<const GpuExecutable::BufferAllocToDeviceMemoryMap*> +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + module_spec.AddCudaCubinInMemory(cubin()); + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + ConstantBufferAllocationToGlobalName(allocation), module_handle)); + VLOG(3) << "Resolved global " + << ConstantBufferAllocationToGlobalName(allocation) << " to " + << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, @@ -192,6 +237,10 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -213,8 +262,12 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( |