aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc38
1 files changed, 36 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 000795ff00..2e0137a485 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -220,6 +220,15 @@ void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
implementation_->UnloadKernel(kernel);
}
+bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ return implementation_->LoadModule(spec, module_handle);
+}
+
+bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
+ return implementation_->UnloadModule(module_handle);
+}
+
void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
<< ") mem->size()=" << mem->size() << StackTraceIfVLOG10();
@@ -459,9 +468,34 @@ void *StreamExecutor::Allocate(uint64 size) {
return buf;
}
-bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
+port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
+ const string &symbol_name, ModuleHandle module_handle) {
+ // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
+ // be nullptr/0 for consistency with DeviceMemory semantics.
+ void *opaque = nullptr;
+ size_t bytes = 0;
+ if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
+ return DeviceMemoryBase(opaque, bytes);
+ }
+
+ if (static_cast<bool>(module_handle)) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if module containing symbol ", symbol_name,
+ " is loaded (module_handle = ",
+ reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
+ } else {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if kernel using the symbol is loaded: ",
+ symbol_name));
+ }
+}
+
+bool StreamExecutor::GetSymbol(const string &symbol_name,
+ ModuleHandle module_handle, void **mem,
size_t *bytes) {
- return implementation_->GetSymbol(symbol_name, mem, bytes);
+ return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}
void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {