aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r-- tensorflow/stream_executor/stream_executor_pimpl.h | 76
1 files changed, 62 insertions, 14 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index ad80a1ba25..47b3a2b030 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -106,6 +106,16 @@ class StreamExecutor {
// Releases any state associated with the previously loaded kernel.
void UnloadKernel(const KernelBase *kernel);
+ // Loads a module for the platform this StreamExecutor is acting upon.
+ //
+ // `spec` describes the module to be loaded. On success writes the handle for
+ // the loaded module to `module_handle` and returns true. Else returns false.
+ bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle);
+
+ // Unloads the module with handle `module_handle`.
+ bool UnloadModule(ModuleHandle module_handle);
+
// Synchronously allocates an array on the device of type T with element_count
// elements.
template <typename T>
@@ -169,8 +179,16 @@ class StreamExecutor {
// type of symbol and T match.
// - Note: symbol_name should include its namespace as well. For example,
// pass "nms0::symbol" if referring to nms0::symbol.
+ //
+ // If `module_handle` is set then searches only within the module
+ // corresponding to `module_handle`.
template <typename T>
- port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
+ port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
+ ModuleHandle module_handle = {});
+
+ // An untyped version of GetSymbol.
+ port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
+ const string &symbol_name, ModuleHandle module_handle = {});
// Deallocate the DeviceMemory previously allocated via this interface.
// Deallocation of a nullptr-representative value is permitted.
@@ -507,7 +525,8 @@ class StreamExecutor {
// Finds and retrieves device memory for the symbol on the underlying
// platform.
- bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);
+ bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes);
// Entrains a memcpy operation onto stream, with a host destination location
// host_dst and a device memory source, with target size size.
@@ -678,6 +697,41 @@ class StreamExecutor {
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};
+// A move-only wrapper around ModuleHandle that uses RAII to manage its
+// lifetime. While an instance holds a non-empty handle, destroying it or
+// move-assigning over it unloads the module through the StreamExecutor it was
+// constructed with; a failed unload is fatal (CHECK).
+class ScopedModuleHandle {
+ public:
+  // Takes ownership of `module_handle`. `executor` is not owned; it is
+  // dereferenced when the handle is released, so it must still be valid at
+  // that point whenever `module_handle` is non-empty.
+  explicit ScopedModuleHandle(StreamExecutor *executor,
+                              ModuleHandle module_handle)
+      : executor_(executor), module_handle_(module_handle) {}
+
+  // Transfers ownership from `other`, leaving it empty so that its destructor
+  // does not unload the module a second time.
+  ScopedModuleHandle(ScopedModuleHandle &&other)
+      : executor_(other.executor_), module_handle_(other.module_handle_) {
+    other.executor_ = nullptr;
+    other.module_handle_ = ModuleHandle();
+  }
+
+  // Releases whatever this instance currently owns, then transfers ownership
+  // from `other` and leaves `other` empty.
+  ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
+    // Self-move must be a no-op: otherwise the handle would be cleared
+    // without ever being unloaded.
+    if (this != &other) {
+      // Unload the module held so far; simply overwriting it would leak it.
+      Reset();
+      executor_ = other.executor_;
+      module_handle_ = other.module_handle_;
+      other.executor_ = nullptr;
+      other.module_handle_ = ModuleHandle();
+    }
+    return *this;
+  }
+
+  ~ScopedModuleHandle() { Reset(); }
+
+ private:
+  // Unloads the owned module, if any, and marks this instance as empty.
+  void Reset() {
+    if (static_cast<bool>(module_handle_)) {
+      CHECK(executor_->UnloadModule(module_handle_));
+      module_handle_ = ModuleHandle();
+    }
+  }
+
+  StreamExecutor *executor_;
+  ModuleHandle module_handle_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
+};
+
////////////
// Inlines
@@ -690,19 +744,13 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
// Typed symbol lookup. The new implementation forwards `symbol_name` and
// `module_handle` to GetUntypedSymbol(), returns its status unchanged when the
// lookup fails, and otherwise constructs a DeviceMemory<T> from the untyped
// result.
// NOTE(review): the removed implementation CHECKed that the symbol's byte size
// is a multiple of sizeof(T); no equivalent check is visible in the new body.
// Confirm whether GetUntypedSymbol() or DeviceMemory<T>'s constructor covers
// it.
template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
- const string &symbol_name) {
- // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
- // be nullptr/0 for consistency with DeviceMemory semantics.
- void *opaque = nullptr;
- size_t bytes = 0;
- if (GetSymbol(symbol_name, &opaque, &bytes)) {
- CHECK_EQ(bytes % sizeof(T), 0);
- return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+ const string &symbol_name, ModuleHandle module_handle) {
+ port::StatusOr<DeviceMemoryBase> untyped_symbol =
+ GetUntypedSymbol(symbol_name, module_handle);
// Propagate the untyped lookup's failure status to the caller as-is.
+ if (!untyped_symbol.ok()) {
+ return untyped_symbol.status();
}
- return port::Status(
- port::error::NOT_FOUND,
- port::StrCat("Check if kernel using the symbol is loaded: ",
- symbol_name));
+ return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}
template <typename ElemT>