aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_internal.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-07-23 16:17:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 16:20:36 -0700
commit632e48c27e09b53ab52523149e759f9bc1711e71 (patch)
tree6c080226ca18ed1937b8a2afe973702d0ffffaee /tensorflow/stream_executor/stream_executor_internal.h
parent9225bbbe0aaaa14b69176576097bb67bae98e6c5 (diff)
Teach StreamExecutor to load modules and resolve symbols in them
This will be used in a future CL. PiperOrigin-RevId: 205742731
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_internal.h')
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.h32
1 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index fb1b92cb84..f34b1fc083 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -36,20 +36,38 @@ limitations under the License.
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/trace_listener.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace stream_executor {
class Stream;
class Timer;
+// An opaque handle to a loaded module.
+//
+// An instance of this is returned from StreamExecutor::GetModule.
+class ModuleHandle {
+ public:
+ /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}
+
+ // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
+ // null pointer.
+ void *id() const { return id_; }
+
+ explicit operator bool() const { return id() != nullptr; }
+
+ private:
+ void *id_;
+};
+
namespace internal {
// Platform-dependent interface class for the generic Events interface, in
@@ -164,6 +182,11 @@ class StreamExecutorInterface {
KernelBase *kernel) {
return false;
}
+ virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ return false;
+ }
+ virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
const KernelArgsArrayBase &args) {
@@ -247,7 +270,12 @@ class StreamExecutorInterface {
// null, however, both of them cannot be null at the same time. To use
// constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
// is found.
- virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) {
+ //
+ // If ModuleHandle is set then we search for `symbol_name` only within the
+ // module corresponding to `module_handle`. Otherwise all loaded modules are
+ // searched.
+ virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes) {
return false;
}