diff options
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 76 |
1 files changed, 62 insertions, 14 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index ad80a1ba25..47b3a2b030 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -106,6 +106,16 @@ class StreamExecutor {
   // Releases any state associated with the previously loaded kernel.
   void UnloadKernel(const KernelBase *kernel);
 
+  // Loads a module for the platform this StreamExecutor is acting upon.
+  //
+  // `spec` describes the module to be loaded. On success writes the handle for
+  // the loaded module to `module_handle` and returns true. Else returns false.
+  bool LoadModule(const MultiModuleLoaderSpec &spec,
+                  ModuleHandle *module_handle);
+
+  // Unloads the module with handle `module_handle`.
+  bool UnloadModule(ModuleHandle module_handle);
+
   // Synchronously allocates an array on the device of type T with element_count
   // elements.
   template <typename T>
@@ -169,8 +179,16 @@ class StreamExecutor {
   // type of symbol and T match.
   // - Note: symbol_name should include its namespace as well. For example,
   // pass "nms0::symbol" if referring to nms0::symbol.
+  //
+  // If `module_handle` is set then searches only within the module
+  // corresponding to `module_handle`.
   template <typename T>
-  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
+  port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
+                                            ModuleHandle module_handle = {});
+
+  // An untyped version of GetSymbol.
+  port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
+      const string &symbol_name, ModuleHandle module_handle = {});
 
   // Deallocate the DeviceMemory previously allocated via this interface.
   // Deallocation of a nullptr-representative value is permitted.
@@ -507,7 +525,8 @@ class StreamExecutor {
 
   // Finds and retrieves device memory for the symbol on the underlying
   // platform.
-  bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);
+  bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+                 void **mem, size_t *bytes);
 
   // Entrains a memcpy operation onto stream, with a host destination location
   // host_dst and a device memory source, with target size size.
@@ -678,6 +697,41 @@ class StreamExecutor {
   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
 };
 
+// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
+class ScopedModuleHandle {
+ public:
+  explicit ScopedModuleHandle(StreamExecutor *executor,
+                              ModuleHandle module_handle)
+      : executor_(executor), module_handle_(module_handle) {}
+
+  ScopedModuleHandle(ScopedModuleHandle &&other) {
+    executor_ = other.executor_;
+    module_handle_ = other.module_handle_;
+    other.executor_ = nullptr;
+    other.module_handle_ = ModuleHandle();
+  }
+
+  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
+    executor_ = other.executor_;
+    module_handle_ = other.module_handle_;
+    other.executor_ = nullptr;
+    other.module_handle_ = ModuleHandle();
+    return *this;
+  }
+
+  ~ScopedModuleHandle() {
+    if (static_cast<bool>(module_handle_)) {
+      CHECK(executor_->UnloadModule(module_handle_));
+    }
+  }
+
+ private:
+  StreamExecutor *executor_;
+  ModuleHandle module_handle_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
+};
+
 ////////////
 // Inlines
 
@@ -690,19 +744,13 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
 
 template <typename T>
 inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
-    const string &symbol_name) {
-  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
-  // be nullptr/0 for consistency with DeviceMemory semantics.
-  void *opaque = nullptr;
-  size_t bytes = 0;
-  if (GetSymbol(symbol_name, &opaque, &bytes)) {
-    CHECK_EQ(bytes % sizeof(T), 0);
-    return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+    const string &symbol_name, ModuleHandle module_handle) {
+  port::StatusOr<DeviceMemoryBase> untyped_symbol =
+      GetUntypedSymbol(symbol_name, module_handle);
+  if (!untyped_symbol.ok()) {
+    return untyped_symbol.status();
   }
-  return port::Status(
-      port::error::NOT_FOUND,
-      port::StrCat("Check if kernel using the symbol is loaded: ",
-                   symbol_name));
+  return DeviceMemory<T>(untyped_symbol.ValueOrDie());
 }
 
 template <typename ElemT>