about summary refs log tree commit diff homepage
path: root/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_gpu_executor.h')
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_gpu_executor.h 18
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 773cbfb8a1..8a954d5461 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -62,6 +62,9 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
bool GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) override;
void UnloadKernel(const KernelBase *kernel) override;
+ bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) override;
+ bool UnloadModule(ModuleHandle module_handle) override;
bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
@@ -175,7 +178,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
// Search for the symbol and returns a device pointer and size.
// Returns false if symbol does not exist.
- bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) override;
+ bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes) override;
DeviceDescription *PopulateDeviceDescription() const override;
@@ -210,7 +214,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override;
- void *CudaContextHack() override;
+ void *GpuContextHack() override;
CudaContext* cuda_context();
@@ -239,6 +243,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
void VlogOccupancyInfo(const KernelBase &kernel, const ThreadDim &thread_dims,
const BlockDim &block_dims);
+ bool LoadModuleFromCuBin(const char *cubin, CUmodule *module)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
+ // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated.
+ bool LoadModuleFromPtx(const char *ptx, CUmodule *module)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
+ bool UnloadGpuBinary(const void *gpu_binary)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
// Guards the in-memory-module mapping.
mutex in_memory_modules_mu_;