diff options
Diffstat (limited to 'tensorflow/stream_executor/platform.h')
-rw-r--r-- | tensorflow/stream_executor/platform.h | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h new file mode 100644 index 0000000000..c8b500b424 --- /dev/null +++ b/tensorflow/stream_executor/platform.h @@ -0,0 +1,185 @@ +// Defines types and declares functions for identifying and extracting +// information about the types of platforms and supporting libraries for which +// StreamExecutor implementations exist. +#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_ +#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_ + +#include <map> +#include "tensorflow/stream_executor/platform/port.h" + +#include "tensorflow/stream_executor/device_options.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin.h" +#include "tensorflow/stream_executor/trace_listener.h" + +namespace perftools { +namespace gputools { + +class StreamExecutor; + +// Describes the platform for a StreamExecutor instantiation to act upon. +// +// Implementors: if you add a value here be sure to update PlatformKindString +// and CheckPlatformKindIsValid. +enum class PlatformKind { + kInvalid, + kCuda, + kOpenCL, + kOpenCLAltera, // Altera FPGA OpenCL platform. + // See documentation: go/fpgaopencl + // (StreamExecutor integration) + kHost, + kMock, + kSize, +}; + +// Returns true if kind represents a valid platform capable of enqueuing items +// on a stream, but not necessarily on an accelerator device. +// Returns false for kMock and any invalid PlatformKind values. +bool PlatformIsRunnable(PlatformKind kind); + +// Returns true if kind represents a valid platform capable of running kernels +// on an accelerator device. Returns false for kHost*, kMock and any invalid +// PlatformKind values. +bool PlatformIsRunnableOnDevice(PlatformKind kind); + +// Returns a printable description of a PlatformKind. +string PlatformKindString(PlatformKind kind); + +// Returns the PlatformKind corresponding to the input string; returns kInvalid +// in the case of no match. +PlatformKind PlatformKindFromString(string platform_string); + +// Checks that kind takes on a valid value. +void CheckPlatformKindIsValid(PlatformKind kind); + +// StreamExecutorConfig encapsulates the set of options for constructing a +// StreamExecutor for a given platform. +struct StreamExecutorConfig { + // Sets members to defaults: -1 for ordinal (must be changed), and default + // PluginConfig and DeviceOptions. + StreamExecutorConfig(); + + // Simple ordinal-setting constructor. + explicit StreamExecutorConfig(int ordinal); + + // The ordinal of the device to be managed by the returned StreamExecutor. + int ordinal; + + // The PluginConfig for the returned StreamExecutor. + PluginConfig plugin_config; + + // The DeviceOptions for the returned StreamExecutor. + DeviceOptions device_options; +}; + +// Abstract base class for a platform registered with the MultiPlatformManager. +class Platform { + public: + virtual ~Platform(); + + // A platform ID is a unique identifier for each registered platform type - + // each platform is required to expose an ID to ensure unique registration and + // as a target against which plugins can register. + // + // The macro below is provided to help generate a [process-unique] identifer. + using Id = void*; + +// Helper macro to define a plugin ID. To be used only inside plugin +// implementation files. Works by "reserving" an address/value (guaranteed to be +// unique) inside a process space. +#define PLATFORM_DEFINE_ID(ID_VAR_NAME) \ + namespace { \ + int plugin_id_value; \ + } \ + const perftools::gputools::Platform::Id ID_VAR_NAME = &plugin_id_value; + + // Returns a key uniquely identifying this platform. + virtual Id id() const = 0; + + // Returns the number of devices accessible on this platform. + // + // Note that, though these devices are visible, if there is only one userspace + // context allowed for the device at a time and another process is using this + // device, a call to ExecutorForDevice may return an error status. + virtual int VisibleDeviceCount() const = 0; + + // Name of this platform. + virtual const string& Name() const = 0; + + // Returns a device with the given ordinal on this platform with a default + // plugin configuration or, if none can be found with the given ordinal or + // there is an error in opening a context to communicate with the device, an + // error status is returned. + // + // Ownership of the executor is NOT transferred to the caller -- + // the Platform owns the executors in a singleton-like fashion. + virtual port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) = 0; + + // Returns a device or error, as above, with the specified plugins. + // + // Ownership of the executor is NOT transferred to the caller. + virtual port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig( + int ordinal, const PluginConfig& plugin_config) = 0; + + // Returns a device constructed with the options specified in "config". + // Ownership of the executor is NOT transferred to the caller. + virtual port::StatusOr<StreamExecutor*> GetExecutor( + const StreamExecutorConfig& config) = 0; + + // Returns a device constructed with the options specified in "config" without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. + virtual port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor( + const StreamExecutorConfig& config) = 0; + + // Warning: this is a dangerous API and should be used with caution. + // + // Forces the platform to delete executor instances, releasing their + // associated device contexts. There must be no held instances of the executor + // and there must be no outstanding activity on the devices for this platform. + // + // This is only useful on platforms which bind a device to a single process + // that has obtained the device context. May return UNIMPLEMENTED on platforms + // that have no reason to destroy device contexts. + virtual port::Status ForceExecutorShutdown(); + + // Registers a TraceListener to listen to all StreamExecutors for this + // platform. + // Takes ownership of listener. + virtual void RegisterTraceListener( + std::unique_ptr<TraceListener> listener) = 0; + + // Removes the specified TraceListener from all StreamExecutors. + virtual void UnregisterTraceListener(TraceListener* listener) = 0; + + // Map of executor-to-executor coordinate and boolean, indicating if the first + // executor can access the second's memory. + using PeerAccessMap = std::map<std::pair<int, int>, bool>; + + // Returns a matrix indicating which executors can access which other + // executors' memory. + virtual std::unique_ptr<PeerAccessMap> GetPeerAccessMap(); + + // Attempts to enable all peer-to-peer access links described by the result of + // GetPeerAccessMap(). Note that calling this routine will force the creation + // of a default-argument (see StreamExecutorConfig) StreamExecutor object for + // each device ordinal in the system, should any not yet exist. + virtual port::Status EnablePeerAccess(); + + protected: + // SE_DISALLOW_COPY_AND_ASSIGN declares a constructor, which suppresses the + // presence of the default constructor. This statement re-enables it, which + // simplifies subclassing. + Platform() = default; + + private: + SE_DISALLOW_COPY_AND_ASSIGN(Platform); +}; + +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_ |