Diffstat (limited to 'tensorflow/stream_executor/stream_executor_internal.h')
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h | 66
1 file changed, 48 insertions(+), 18 deletions(-)
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 9c989b971d..f34b1fc083 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -36,20 +36,38 @@ limitations under the License.
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/trace_listener.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace stream_executor {
class Stream;
class Timer;
+// An opaque handle to a loaded module.
+//
+// An instance of this is returned from StreamExecutor::GetModule.
+class ModuleHandle {
+ public:
+ /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}
+
+ // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
+ // null pointer.
+ void *id() const { return id_; }
+
+ explicit operator bool() const { return id() != nullptr; }
+
+ private:
+ void *id_;
+};
+
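Usage sketch (not part of the patch): the explicit bool conversion lets call sites treat a default-constructed handle as "no module". The function and the `tag` variable below are invented purely for illustration.

#include "tensorflow/stream_executor/stream_executor_internal.h"

// Sketch only: demonstrates the validity contract of ModuleHandle.
void ModuleHandleSketch() {
  static int tag;  // stands in for a platform-specific module id
  stream_executor::ModuleHandle invalid;       // id() == nullptr -> false
  stream_executor::ModuleHandle loaded(&tag);  // non-null id     -> true
  if (!invalid) {
    // Behaves like a null pointer: nothing to unload or search.
  }
  if (loaded) {
    // Safe to hand to UnloadModule() or a module-scoped GetSymbol().
  }
}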
namespace internal {
// Platform-dependent interface class for the generic Events interface, in
@@ -100,19 +118,20 @@ class StreamInterface {
// Default destructor for the abstract interface.
virtual ~StreamInterface() {}
- // Returns the CUDA stream associated with this platform's stream
+ // Returns the GPU stream associated with this platform's stream
// implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaStreamHack() { return nullptr; }
-
- // See the above comment on CudaStreamHack -- this further breaks abstraction
- // for Eigen within distbelief, which has strong ties to CUDA as a platform,
- // and a historical attachment to a programming model which takes a
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or
+ // ROCm as a platform.
+ virtual void *GpuStreamHack() { return nullptr; }
+
+ // See the above comment on GpuStreamHack -- this further breaks abstraction
+ // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
+ // platform, and a historical attachment to a programming model which takes a
// stream-slot rather than a stream-value.
- virtual void **CudaStreamMemberHack() { return nullptr; }
+ virtual void **GpuStreamMemberHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
@@ -163,6 +182,11 @@ class StreamExecutorInterface {
KernelBase *kernel) {
return false;
}
+ virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ return false;
+ }
+ virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
const KernelArgsArrayBase &args) {
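Caller-side sketch of the new hooks (hypothetical, assuming some concrete implementation behind `impl` and a pre-built loader spec; neither the function name nor its callers exist in this change):

#include "tensorflow/stream_executor/stream_executor_internal.h"

// Hypothetical caller: bracket a module's lifetime with LoadModule/UnloadModule.
bool WithLoadedModule(stream_executor::internal::StreamExecutorInterface *impl,
                      const stream_executor::MultiModuleLoaderSpec &spec) {
  stream_executor::ModuleHandle handle;
  if (!impl->LoadModule(spec, &handle) || !handle) {
    return false;  // backend does not support modules, or loading failed
  }
  // ... use the module, e.g. via a module-scoped GetSymbol (see below) ...
  return impl->UnloadModule(handle);
}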
@@ -246,7 +270,12 @@ class StreamExecutorInterface {
// null, however, both of them cannot be null at the same time. To use
// constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
// is found.
- virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) {
+ //
+ // If ModuleHandle is set then we search for `symbol_name` only within the
+ // module corresponding to `module_handle`. Otherwise all loaded modules are
+ // searched.
+ virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes) {
return false;
}
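Sketch of the two lookup modes the widened GetSymbol signature allows (again assuming a concrete implementation behind `impl`; the symbol name and function are illustrative only):

#include "tensorflow/stream_executor/stream_executor_internal.h"

// Hypothetical caller: module-scoped versus global symbol lookup.
bool LookupSymbol(stream_executor::internal::StreamExecutorInterface *impl,
                  stream_executor::ModuleHandle handle) {
  void *mem = nullptr;
  size_t bytes = 0;
  // Module-scoped: only the module behind `handle` is searched.
  bool scoped = impl->GetSymbol("my_constant", handle, &mem, &bytes);
  // Global: an invalid (default-constructed) handle means "search every
  // loaded module", matching the behaviour of the previous signature.
  bool global = impl->GetSymbol("my_constant", stream_executor::ModuleHandle(),
                                &mem, &bytes);
  return scoped || global;
}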
@@ -324,13 +353,14 @@ class StreamExecutorInterface {
virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
- // Returns the CUDA context associated with this StreamExecutor platform
- // implementation.
+ // Returns the CUDA or ROCm context associated with this StreamExecutor
+ // platform implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaContextHack() { return nullptr; }
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
+ // as a platform.
+ virtual void *GpuContextHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);