diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-07-23 16:17:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-23 16:20:36 -0700 |
commit | 632e48c27e09b53ab52523149e759f9bc1711e71 (patch) | |
tree | 6c080226ca18ed1937b8a2afe973702d0ffffaee /tensorflow/stream_executor/stream_executor_internal.h | |
parent | 9225bbbe0aaaa14b69176576097bb67bae98e6c5 (diff) |
Teach StreamExecutor to load modules and resolve symbols in them
This will be used in a future CL.
PiperOrigin-RevId: 205742731
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_internal.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_internal.h | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index fb1b92cb84..f34b1fc083 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -36,20 +36,38 @@ limitations under the License. #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/kernel_spec.h" #include "tensorflow/stream_executor/launch_dim.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/module_spec.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/shared_memory_config.h" #include "tensorflow/stream_executor/trace_listener.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" namespace stream_executor { class Stream; class Timer; +// An opaque handle to a loaded module. +// +// An instance of this is returned from StreamExecutor::GetModule. +class ModuleHandle { + public: + /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {} + + // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a + // null pointer. + void *id() const { return id_; } + + explicit operator bool() const { return id() != nullptr; } + + private: + void *id_; +}; + namespace internal { // Platform-dependent interface class for the generic Events interface, in @@ -164,6 +182,11 @@ class StreamExecutorInterface { KernelBase *kernel) { return false; } + virtual bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + return false; + } + virtual bool UnloadModule(ModuleHandle module_handle) { return false; } virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, const KernelArgsArrayBase &args) { @@ -247,7 +270,12 @@ class StreamExecutorInterface { // null, however, both of them cannot be null at the same time. To use // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol // is found. - virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) { + // + // If ModuleHandle is set then we search for `symbol_name` only within the + // module corresponding to `module_handle`. Otherwise all loaded modules are + // searched. + virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes) { return false; } |