diff options
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_internal.h')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_internal.h | 364 |
1 files changed, 364 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h new file mode 100644 index 0000000000..5b4e596cfe --- /dev/null +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -0,0 +1,364 @@ +// Interfaces for platform-dependent implementations to satisfy. This are +// delegated to from the StreamExecutor in pointer-to-implementation style; i.e. +// the StreamExecutor is just a husk that delegates calls to the +// platform-specific objects which implement the interfaces defined here. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ +#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ + +#include <functional> +#include <map> +#include <memory> +#include <utility> +#include <vector> + +#include "tensorflow/stream_executor/device_description.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_options.h" +#include "tensorflow/stream_executor/dnn.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/kernel.h" +#include "tensorflow/stream_executor/kernel_cache_config.h" +#include "tensorflow/stream_executor/kernel_spec.h" +#include "tensorflow/stream_executor/launch_dim.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/shared_memory_config.h" +#include "tensorflow/stream_executor/trace_listener.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" + +namespace perftools { +namespace gputools { + +class KernelBase; +class Stream; +class Timer; + +namespace blas { +class BlasSupport; +} // namespace blas + +namespace fft { +class Support; +} // namespace fft + +namespace rng { +class RngSupport; +} // namespace rng + +} // namespace gputools +} // namespace perftools + +namespace perftools { +namespace gputools { +namespace internal { + +// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL). +// +// Various platforms will provide an implementation that satisfy this interface. +class StreamExecutorInterface { + public: + // Default constructor for the abstract interface. + StreamExecutorInterface() {} + + // Default destructor for the abstract interface. + virtual ~StreamExecutorInterface() {} + + // Returns the (transitively) wrapped executor if this executor is + // wrapping another executor; otherwise, returns this. + virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; } + + // See the StreamExecutor interface for comments on the same-named methods. + virtual port::Status Init(int device_ordinal, + DeviceOptions device_options) = 0; + virtual bool GetKernel(const MultiKernelLoaderSpec &spec, + KernelBase *kernel) { + return false; + } + virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, + const BlockDim &block_dims, const KernelBase &k, + const std::vector<KernelArg> &args) { + return false; + } + virtual void *Allocate(uint64 size) = 0; + virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset, + uint64 size) = 0; + virtual void Deallocate(DeviceMemoryBase *mem) = 0; + virtual void *HostMemoryAllocate(uint64 size) = 0; + virtual void HostMemoryDeallocate(void *mem) = 0; + virtual bool HostMemoryRegister(void *mem, uint64 size) = 0; + virtual bool HostMemoryUnregister(void *mem) = 0; + virtual bool SynchronizeAllActivity() = 0; + virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0; + virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value, + uint64 size) = 0; + virtual bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst, + const void *host_src, uint64 size) = 0; + virtual bool SynchronousMemcpy(void *host_dst, + const DeviceMemoryBase &gpu_src, + uint64 size) = 0; + virtual bool SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &gpu_src, + uint64 size) = 0; + virtual bool MemZero(Stream *stream, DeviceMemoryBase *location, + uint64 size) = 0; + virtual bool Memset32(Stream *stream, DeviceMemoryBase *location, + uint32 pattern, uint64 size) = 0; + virtual bool Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &gpu_src, uint64 size) = 0; + virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, + const void *host_src, uint64 size) = 0; + virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, + const DeviceMemoryBase &host_src, + uint64 size) = 0; + virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0; + virtual port::Status AllocateEvent(Event *event) = 0; + virtual port::Status DeallocateEvent(Event *event) = 0; + virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; + virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0; + virtual Event::Status PollForEventStatus(Event *event) = 0; + virtual bool AllocateStream(Stream *stream) = 0; + virtual void DeallocateStream(Stream *stream) = 0; + virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0; + virtual bool AllocateTimer(Timer *timer) = 0; + virtual void DeallocateTimer(Timer *timer) = 0; + virtual bool StartTimer(Stream *stream, Timer *timer) = 0; + virtual bool StopTimer(Stream *stream, Timer *timer) = 0; + virtual bool BlockHostUntilDone(Stream *stream) = 0; + virtual int PlatformDeviceCount() = 0; + virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0; + virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0; + virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0; + virtual port::Status SetDeviceSharedMemoryConfig( + SharedMemoryConfig config) = 0; + + virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const { + return false; + } + + // Retrieves device pointer and size for a symbol. The device pointer is + // stored at mem, and the size is stored at size. Either mem or bytes can be + // null, however, both of them cannot be null at the same time. To use + // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol + // is found. + virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) { + return false; + } + + // Creates a new DeviceDescription object. Ownership is transferred to the + // caller. + virtual DeviceDescription *PopulateDeviceDescription() const = 0; + + virtual KernelArg DeviceMemoryToKernelArg( + const DeviceMemoryBase &gpu_mem) const = 0; + + // Attempts to register the provided TraceListener with the device-specific + // Executor implementation. When this is called, the PIMPL interface has + // already taken ownership of the object and is managing the generic tracing + // events. The device-specific implementation must determine if the passed + // listener is of a type appropriate for it to trace during registration (and + // before dispatching events to it). + // Returns true if the listener was successfully registered, false otherwise. + // Does not take ownership of listener. + virtual bool RegisterTraceListener(TraceListener* listener) { return false; } + + // Unregisters the specified listener from the device-specific Executor. + // Returns true if the listener was successfully registered, false otherwise. + virtual bool UnregisterTraceListener(TraceListener* listener) { + return false; + } + + // Returns whether this StreamExecutor has BLAS support for its underlying + // platform. + virtual bool SupportsBlas() const { return false; } + + // Creates a new BlasSupport object, ownership is transferred to the caller. + // If SupportsBlas() is false, this will always return null. + // + // If SupportsBlas() is true, this may return null, for example, if the BLAS + // initialization fails. + virtual blas::BlasSupport *CreateBlas() { return nullptr; } + + // Returns whether this StreamExecutor has FFT support for its underlying + // platform. + virtual bool SupportsFft() const { return false; } + + // Creates a new fft::FftSupport object, ownership is transferred to the + // caller. + // If SupportsFft() is false, this will always return null. + // + // If SupportsFft() is true, this may return null, for example, if the FFT + // initialization fails. + virtual fft::FftSupport *CreateFft() { return nullptr; } + + // Returns whether this StreamExecutor has Random Number Generation support + // for + // its underlying platform. + virtual bool SupportsRng() const { return false; } + + // Returns whether this StreamExecutor has neural net support for its + // underlying + // platform. + virtual bool SupportsDnn() const { return false; } + + // Creates a new RngSupport object, ownership is transferred to the caller. + // If SupportsRng() is false, this will always return null. + // + // If SupportsRng() is true, this may return null, for example, if the RNG + // initialization fails. + virtual rng::RngSupport *CreateRng() { return nullptr; } + + // Creates a new DnnSupport object, ownership is transferred to the caller. + // If SupportsDnn() is false, this will always return null. + // + // If SupportsDnn() is true, this may return null, for example, if the RNG + // initialization fails. + virtual dnn::DnnSupport *CreateDnn() { return nullptr; } + + // Please read the warning below. This method is only temporary. See + // http://b/15759750 + // + // Returns the CUDA context associated with this StreamExecutor platform + // implementation. + // + // WARNING: checks that the underlying platform is, in fact, CUDA, causing a + // fatal error if it is not. This hack is made available solely for use from + // distbelief code, which temporarily has strong ties to CUDA as a platform. + virtual void *CudaContextHack() { return nullptr; } + + private: + SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); +}; + +// Pointer-to-implementation object type (i.e. the KernelBase class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any kernel data/resource info/functionality +// off of. +class KernelInterface { + public: + // Default constructor for the abstract interface. + KernelInterface() {} + + // Default destructor for the abstract interface. + virtual ~KernelInterface() {} + + // Returns the number of formal parameters that this kernel accepts. + virtual unsigned Arity() const = 0; + + // Sets the preferred cache configuration. + virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; + + // Gets the preferred cache configuration. + virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; + + private: + SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); +}; + +// Platform-dependent interface class for the generic Events interface, in +// the PIMPL style. +class EventInterface { + public: + EventInterface() {} + virtual ~EventInterface() {} + + private: + SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); +}; + +// Pointer-to-implementation object type (i.e. the Stream class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any kernel data/resource info/functionality +// off of. +class StreamInterface { + public: + // Default constructor for the abstract interface. + StreamInterface() {} + + // Default destructor for the abstract interface. + virtual ~StreamInterface() {} + + // Please read the warning below. This method is only temporary. See + // http://b/15759750 + // + // Returns the CUDA stream associated with this platform's stream + // implementation. + // + // WARNING: checks that the underlying platform is, in fact, CUDA, causing a + // fatal error if it is not. This hack is made available solely for use from + // distbelief code, which temporarily has strong ties to CUDA as a platform. + virtual void *CudaStreamHack() { return nullptr; } + + // Please read the warning above. This method is only temporary. See + // http://b/15759750 + // + // See the above comment on CudaStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA as a platform, + // and a historical attachment to a programming model which takes a + // stream-slot rather than a stream-value. + virtual void **CudaStreamMemberHack() { return nullptr; } + + private: + SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); +}; + +// Pointer-to-implementation object type (i.e. the Timer class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any timer data/resource info/functionality +// off of. +class TimerInterface { + public: + // Default constructor for the abstract interface. + TimerInterface() {} + + // Default destructor for the abstract interface. + virtual ~TimerInterface() {} + + // Returns the number of microseconds elapsed in a completed timer. + virtual uint64 Microseconds() const = 0; + + // Returns the number of nanoseconds elapsed in a completed timer. + virtual uint64 Nanoseconds() const = 0; + + private: + SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); +}; + +// Extern functions for constructing platform-specific instances that conform to +// the StreamExecutor interface. (Defining constructor functions extern in this +// way prevents CUDA/OpenCL headers from leaking into any shared header files.) +// +// TODO(leary) switch this all over to registries. + +using StreamExecutorFactory = + std::function<StreamExecutorInterface *(const PluginConfig &)>; +using EventFactory = std::function<EventInterface *(StreamExecutor *)>; +using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>; +using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>; +using KernelFactory = std::function<KernelInterface*()>; + +EventFactory* MakeCUDAEventImplementation(); +StreamExecutorFactory* MakeCUDAExecutorImplementation(); +StreamFactory* MakeCUDAStreamImplementation(); +TimerFactory* MakeCUDATimerImplementation(); +KernelFactory* MakeCUDAKernelImplementation(); + +StreamExecutorFactory* MakeOpenCLExecutorImplementation(); +StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation(); +StreamFactory* MakeOpenCLStreamImplementation(); +TimerFactory* MakeOpenCLTimerImplementation(); +KernelFactory* MakeOpenCLKernelImplementation(); + +extern StreamExecutorFactory MakeHostExecutorImplementation; +extern StreamFactory MakeHostStreamImplementation; +extern TimerFactory MakeHostTimerImplementation; + + +} // namespace internal +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ |