Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.h')
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h  725
1 file changed, 725 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
new file mode 100644
index 0000000000..29ab235d0e
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -0,0 +1,725 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
+
+#include <atomic>
+#include <set>
+#include <tuple>
+#include <vector>
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/shared_memory_config.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+
+// Structure used for device memory leak checking.
+struct AllocRecord {
+ // The requested allocation size of the buffer.
+ uint64 bytes;
+
+ // Holds a representation of the stack at the time the associated buffer was
+ // allocated. Produced in a form described in
+ // //util/symbolize/symbolized_stacktrace.h.
+ string stack_trace;
+};
+
+// Forward declaration of private friend class.
+template <typename BeginCallT, typename CompleteCallT,
+ typename ReturnT, typename... BeginArgsT>
+class ScopedTracer;
+
+// A StreamExecutor manages a single device, in terms of executing work (kernel
+// launches) and memory management (allocation/deallocation, memory copies to
+// and from the device). It is conceptually the "handle" for a device -- Stream
+// objects, which are used to enqueue work to run on the coprocessor, have a
+// StreamExecutor instance as their "parent" object.
+//
+// StreamExecutor objects have an underlying platform that is specified up
+// front; e.g. it is either a CUDA or an OpenCL executor.
+//
+// Thread-safe after initialization.
+// StreamExecutor interface should not be invoked from a signal handler.
+class StreamExecutor {
+ public:
+ explicit StreamExecutor(PlatformKind kind,
+ const PluginConfig &plugin_config = PluginConfig());
+
+ // Primarily used for testing.
+ StreamExecutor(PlatformKind kind,
+ internal::StreamExecutorInterface *implementation);
+
+ ~StreamExecutor();
+
+ port::Status Init();
+ port::Status Init(int device_ordinal, DeviceOptions device_options);
+
+ // Returns the platform that this StreamExecutor is acting upon.
+ PlatformKind platform_kind() const { return platform_kind_; }
+
+ // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
+ // upon, if one exists.
+ //
+ // Parameters:
+ // spec: The MultiKernelLoaderSpec is usually generated as a compile-time
+ // constant into an appropriate namespace. For example, see
+ // perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a
+ // MultiKernelLoaderSpec is selected.
+ // kernel: Outparam that the kernel is loaded into. A given Kernel
+ // instantiation should not be loaded into more than once.
+ //
+ // If an error occurs, or there is no kernel available for the StreamExecutor
+ // platform, false is returned.
+ bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
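+
+  // Illustrative sketch of kernel loading (a minimal outline only; "spec" is
+  // assumed to be a suitable MultiKernelLoaderSpec for the target platform and
+  // error handling is elided):
+  //
+  //   KernelBase kernel(executor);
+  //   if (!executor->GetKernel(spec, &kernel)) {
+  //     ... no kernel variant available for this platform, or load failed ...
+  //   }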
+
+ // Synchronously allocates an array on the GPU device of type T with
+ // element_count elements.
+ template <typename T>
+ DeviceMemory<T> AllocateArray(uint64 element_count);
+
+ // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
+ return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
+ }
+
+ // Convenience wrapper that allocates space for a single element of type T
+ // in GPU memory.
+ template <typename T>
+ DeviceMemory<T> AllocateScalar() {
+ return AllocateArray<T>(1);
+ }
+
+ // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedScalar() {
+ return AllocateOwnedArray<T>(1);
+ }
+
+ // Synchronously allocates a scalar of type T on the GPU device that is
+ // (POD) zero-byte initialized.
+ template <typename T>
+ DeviceMemory<T> AllocateZeroed();
+
+ // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedZeroed() {
+ return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
+ }
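+
+  // Illustrative allocation sketch (assumes an initialized StreamExecutor
+  // named "executor"; error handling elided):
+  //
+  //   // 1024 floats in device memory, freed manually via Deallocate().
+  //   DeviceMemory<float> raw = executor->AllocateArray<float>(1024);
+  //   ...
+  //   executor->Deallocate(&raw);
+  //
+  //   // The same allocation, freed automatically when "owned" leaves scope.
+  //   ScopedDeviceMemory<float> owned(executor,
+  //                                   executor->AllocateArray<float>(1024));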
+
+ // Allocate a memory region inside another allocated memory region.
+ // Offset and size are specified in terms of T elements.
+ // Warning: Do not free a parent buffer before its sub-buffers; this may cause
+ // use-after-free issues (the specific behavior is not consistent across
+ // platforms).
+ // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
+ // sub-buffer after parent deallocation is expected to be safe. This will
+ // render your code non-platform-portable, however.
+ template <typename T>
+ DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count);
+
+ // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count) {
+ return ScopedDeviceMemory<T>(
+ this, AllocateSubBuffer<T>(parent, element_offset, element_count));
+ }
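+
+  // Illustrative sketch: carve the second half out of a 1024-element parent
+  // buffer (offset and count are in elements of T, not bytes):
+  //
+  //   DeviceMemory<float> parent = executor->AllocateArray<float>(1024);
+  //   DeviceMemory<float> tail = executor->AllocateSubBuffer<float>(
+  //       &parent, /*element_offset=*/512, /*element_count=*/512);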
+
+  // Finds a symbol and returns device memory allocated to the symbol. The
+  // symbol is searched for in any kernels that were previously loaded through
+  // GetKernel() before the GetSymbol() call. The user has to make sure that
+  // the type of the symbol and T match.
+  // - Note: symbol_name should include its namespace as well. For example,
+  //   pass "nms0::symbol" if referring to nms0::symbol.
+ template <typename T>
+ port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
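+
+  // Illustrative sketch (the symbol name is hypothetical; the kernel defining
+  // it must already have been loaded via GetKernel()):
+  //
+  //   port::StatusOr<DeviceMemory<float>> sym =
+  //       executor->GetSymbol<float>("nms0::symbol");
+  //   if (sym.ok()) {
+  //     DeviceMemory<float> mem = sym.ValueOrDie();
+  //     ...
+  //   }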
+
+ // Deallocate the DeviceMemory previously allocated via this interface.
+ // Deallocation of a nullptr-representative value is permitted.
+ //
+ // Resets the internal contents of mem to be null-representative, but this
+ // null-out effect should not be relied upon in client code.
+ void Deallocate(DeviceMemoryBase *mem);
+
+ // Retrieves a mapping of active opaque GPU memory pointer to a string
+ // representation of the [allocating thread's] stack at the time the pointer
+ // was allocated. Useful for tracking GPU memory leaks.
+ //
+ // Note: this will only be populated if --check_gpu_leaks flag is activated.
+ void GetMemAllocs(std::map<void *, AllocRecord> *records_out);
+
+  // Allocates a region of host memory and registers it with the platform API.
+  // Memory allocated in this manner (or allocated and registered with
+  // HostMemoryRegister()) is required for use in asynchronous memcpy
+  // operations, such as Stream::ThenMemcpy.
+ void *HostMemoryAllocate(uint64 bytes);
+
+ // Deallocates a region of host memory allocated by HostMemoryAllocate().
+ void HostMemoryDeallocate(void *location);
+
+ // Registers a region of host memory with the platform API. Registered memory
+ // (or memory allocated with HostMemoryAllocate) is required for use with
+ // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
+ // is used to register memory allocated outside the StreamExecutor;
+ // HostMemoryAllocate implicitly registers its allocations and
+ // HostMemoryDeallocate implicitly deregisters on deallocation.
+ bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;
+
+ // Unregisters a region of host memory registered with HostMemoryRegister.
+ // This should be done before deallocating the region with delete[]/free/etc.
+ bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
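+
+  // Illustrative sketch of registering externally-allocated host memory so it
+  // can be used with asynchronous copies (error handling elided; kBytes is a
+  // placeholder size):
+  //
+  //   void *staging = malloc(kBytes);
+  //   if (!executor->HostMemoryRegister(staging, kBytes)) { ... }
+  //   ... use staging as a source/destination for Stream::ThenMemcpy ...
+  //   if (!executor->HostMemoryUnregister(staging)) { ... }
+  //   free(staging);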
+
+  // Synchronizes all activity occurring in the StreamExecutor's context (most
+  // likely a whole device).
+ bool SynchronizeAllActivity() SE_MUST_USE_RESULT;
+
+ // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
+ // given location in GPU memory.
+ bool SynchronousMemZero(DeviceMemoryBase *location,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Blocks the caller while "size" bytes are initialized to "value" (in POD
+ // fashion) at the given location in GPU memory.
+ bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // [deprecated] Blocks the caller while a data segment of the given size is
+ // copied from the host source to the GPU destination.
+ //
+ // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // [deprecated] Blocks the caller while a data segment of the given size is
+ // copied from the GPU source to the host destination.
+ //
+ // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
+ port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
+ DeviceMemoryBase *gpu_dst);
+
+  // Alternative interface for memcpying from host to device that takes an
+  // array slice. Checks that the destination size can accommodate the host
+  // slice size.
+ template <class T>
+ port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
+ DeviceMemoryBase *gpu_dst) {
+ auto host_size = host_src.size() * sizeof(T);
+ CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
+ return SynchronousMemcpyH2D(host_src.begin(), host_size, gpu_dst);
+ }
+
+ // Same as SynchronousMemcpy(void*, ...) above.
+ port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &gpu_src, int64 size,
+ void *host_dst);
+
+  // Alternative interface for memcpying from device to host that takes an
+  // array slice. Checks that the host destination can accommodate the device
+  // source size.
+ template <typename T>
+ port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &gpu_src,
+ port::MutableArraySlice<T> host_dst) {
+ auto host_size = host_dst.size() * sizeof(T);
+ CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
+ return SynchronousMemcpyD2H(gpu_src, host_size, host_dst.begin());
+ }
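+
+  // Illustrative host/device round trip using the ArraySlice overloads
+  // (assumes "executor" is initialized; status/error handling elided):
+  //
+  //   std::vector<float> host(16, 1.0f);
+  //   DeviceMemory<float> dev = executor->AllocateArray<float>(host.size());
+  //   executor->SynchronousMemcpyH2D(
+  //       port::ArraySlice<float>(host.data(), host.size()), &dev);
+  //   std::vector<float> back(host.size());
+  //   executor->SynchronousMemcpyD2H(
+  //       dev, port::MutableArraySlice<float>(back.data(), back.size()));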
+
+ // Blocks the caller while a data segment of the given size is copied from the
+ // GPU source to the GPU destination.
+ bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Enqueues an operation onto stream to zero out size bytes at the given GPU
+ // memory location. Neither stream nor location may be null. Returns whether
+ // the operation was successfully enqueued onto the stream.
+ bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) SE_MUST_USE_RESULT;
+
+  // Enqueues an operation onto stream to set 32-bit patterns starting at
+  // location, for byte count given by size. size must cover a whole number of
+  // 32-bit words (i.e. be evenly divisible by 4). Returns whether the
+  // operation was successfully enqueued onto the stream.
+ bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+ uint64 size) SE_MUST_USE_RESULT;
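+
+  // Illustrative sketch (assumes "stream" was created against this executor
+  // and "dev" refers to at least 256 bytes of device memory):
+  //
+  //   if (!executor->MemZero(&stream, &dev, 256)) { ... }
+  //   // The byte count passed to Memset32 must be divisible by 4.
+  //   if (!executor->Memset32(&stream, &dev, 0xDEADBEEF, 256)) { ... }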
+
+  // Enables peer access from this StreamExecutor to memory allocated by other,
+  // such that launched device code, memcpies, etc. may access it directly.
+  //
+  // Both this StreamExecutor and other must be backed by the same platform
+  // implementation (e.g. CUDA vs OpenCL).
+ port::Status EnablePeerAccessTo(StreamExecutor *other);
+
+  // Returns whether it's possible to enable peer access from this
+  // StreamExecutor to memory allocated by another.
+ //
+ // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
+ // this is more an up-front test as to whether it's expressly forbidden.
+ bool CanEnablePeerAccessTo(StreamExecutor *other);
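+
+  // Illustrative sketch for enabling peer access between two executors backed
+  // by the same platform (names are placeholders):
+  //
+  //   if (executor0->CanEnablePeerAccessTo(executor1)) {
+  //     port::Status status = executor0->EnablePeerAccessTo(executor1);
+  //     if (!status.ok()) { ... }
+  //   }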
+
+ // Gets the preferred shared memory configuration for the device to which this
+ // executor is bound.
+ SharedMemoryConfig GetDeviceSharedMemoryConfig();
+
+ // Sets the preferred shared memory configuration for the device to which this
+ // executor is bound.
+ port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);
+
+ // Obtains metadata about the underlying device.
+ // The value is cached on first use.
+ const DeviceDescription &GetDeviceDescription() const;
+
+  // Returns the underlying device memory usage information, if it is
+  // available. If it is not available (false is returned), free/total may not
+  // be initialized.
+  //
+  // Note: "Free" reflects the amount of free memory on the underlying device,
+  // so allocations via other StreamExecutors that have the same underlying
+  // device will be reflected in "free".
+ bool DeviceMemoryUsage(int64 *free, int64 *total) const;
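+
+  // Illustrative sketch:
+  //
+  //   int64 free_bytes, total_bytes;
+  //   if (executor->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
+  //     VLOG(1) << free_bytes << " of " << total_bytes << " bytes free";
+  //   }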
+
+ // The device count reported by this StreamExecutor's platform.
+ // Note: on OpenCL we implicitly select platform zero at the moment.
+ int PlatformDeviceCount() const;
+
+ // Returns whether the StreamExecutor supports BLAS routines for the platform
+ // that underlies this interface.
+ bool SupportsBlas() const;
+
+ // Returns whether the StreamExecutor supports FFT routines for the platform
+ // that underlies this interface.
+ bool SupportsFft() const;
+
+ // Returns whether the StreamExecutor supports RNG routines for the platform
+ // that underlies this interface.
+ bool SupportsRng() const;
+
+  // Returns whether the StreamExecutor supports neural net routines for the
+  // platform that underlies this interface.
+ bool SupportsDnn() const;
+
+ // Returns the device ordinal that this StreamExecutor was initialized with.
+ // Meaningless before initialization.
+ int device_ordinal() const { return device_ordinal_; }
+
+ // Returns a borrowed pointer to the underlying StreamExecutor implementation.
+ internal::StreamExecutorInterface *implementation();
+
+  // Warning: use Stream::ThenLaunch instead; this method is not for general
+ // consumption. However, this is the only way to launch a kernel for which
+ // the type signature is only known at runtime; say, if an application
+ // supports loading/launching kernels with arbitrary type signatures.
+ // In this case, the application is expected to know how to do parameter
+ // packing that obeys the contract of the underlying platform implementation.
+ //
+ // Launches a data parallel kernel with the given thread/block
+ // dimensionality and already-packed args/sizes to pass to the underlying
+ // platform driver.
+ //
+ // This is called by Stream::Launch() to delegate to the platform's launch
+ // implementation in StreamExecutorInterface::Launch().
+ bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &kernel,
+ const std::vector<KernelArg> &args);
+
+  // Gets-or-creates (creates with memoization) an FftSupport datatype that can
+  // be used to execute FFT routines on the current platform.
+  //
+  // Ownership and user-facing behavior are the same as for AsBlas() below.
+ //
+ // Returns null if there was an error initializing the FFT support for the
+ // underlying platform.
+ fft::FftSupport *AsFft();
+
+ // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
+ // be used for neural network routines on the current platform.
+ //
+  // Ownership and user-facing behavior are the same as for AsBlas() below.
+ //
+ // Returns null if there was an error initializing the DNN support for the
+ // underlying platform.
+ dnn::DnnSupport *AsDnn();
+
+ // Turns StreamExecutor operation tracing on or off.
+ void EnableTracing(bool enable);
+
+ // Registers a trace listener to receive callbacks for only a single
+ // StreamExecutor instance.
+ // To register a listener for all executors for a given platform, see
+ // Platform::RegisterTraceListener().
+ // Does not take ownership of listener.
+ void RegisterTraceListener(TraceListener* listener);
+
+ // Removes a TraceListener from this StreamExecutor instance.
+ // Returns false (and logs) in cases where the argument listener was not
+ // previously registered.
+ bool UnregisterTraceListener(TraceListener* listener);
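+
+  // Illustrative tracing sketch ("MyListener" is a hypothetical TraceListener
+  // subclass; the executor does not take ownership of it):
+  //
+  //   MyListener listener;
+  //   executor->RegisterTraceListener(&listener);
+  //   executor->EnableTracing(true);
+  //   ... traced work ...
+  //   executor->EnableTracing(false);
+  //   executor->UnregisterTraceListener(&listener);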
+
+ // Converts a DeviceMemory object into a KernelArg object for passing to the
+ // device driver for kernel launch.
+ KernelArg DeviceMemoryToKernelArg(const DeviceMemoryBase &gpu_mem) const;
+
+ private:
+ template <typename BeginCallT, typename CompleteCallT,
+ typename ReturnT, typename... BeginArgsT>
+ friend class ScopedTracer;
+ friend class Event;
+ friend class Stream;
+ friend class Timer;
+ template <typename... Params>
+ friend class TypedKernel;
+ template <typename... Args>
+ friend struct ThenBlasImpl;
+
+ // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
+ // be used to execute BLAS routines on the current platform. This is typically
+ // not user-facing, as users will use the Stream::ThenBlas* family of routines
+ // to entrain BLAS operations. See blas.h for additional details.
+ //
+ // Ownership is not transferred to the caller -- ownership is retained by this
+ // object for memoization. This BLAS interface is also only expected to be
+ // used by a Stream for entraining calls to BLAS functionality.
+ //
+ // Returns null if there was an error initializing the BLAS support for the
+ // underlying platform.
+ blas::BlasSupport *AsBlas();
+
+ // Gets-or-creates (creates with memoization) an RngSupport datatype that can
+ // be used for random-number-generation routines on the current platform.
+ //
+  // Ownership and user-facing behavior are the same as for AsBlas() above.
+ //
+ // Returns null if there was an error initializing the RNG support for the
+ // underlying platform.
+ rng::RngSupport *AsRng();
+
+ // Causes the host code to synchronously wait for operations entrained onto
+ // stream to complete. Effectively a join on the asynchronous GPU operations
+ // enqueued on the stream before this program point.
+ bool BlockHostUntilDone(Stream *stream);
+
+ // Synchronously allocates size bytes on the underlying platform and returns
+ // an opaque void* representing that allocation. In the case of failure,
+ // nullptr is returned.
+ void *Allocate(uint64 size);
+
+ // Finds and retrieves device memory for the symbol on the underlying
+ // platform.
+ bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);
+
+ // Entrains a memcpy operation onto stream, with a host destination location
+ // host_dst and a GPU memory source, with target size size.
+ bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Entrains a memcpy operation onto stream, with a GPU destination location
+ // and a host memory source, with target size size.
+ bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size);
+
+ // Entrains a memcpy operation onto stream, with a GPU destination location
+ // and a GPU source location, with target size size. Peer access should have
+ // been enabled between the StreamExecutors owning the GPU memory regions.
+ bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size);
+
+ // Entrains on a stream a user-specified function to be run on the host.
+ // See Stream::ThenDoHostCallback for full details.
+ bool HostCallback(Stream *stream, std::function<void()> callback);
+
+ // Performs platform-specific allocation and initialization of an event.
+ port::Status AllocateEvent(Event *event);
+
+ // Performs platform-specific deallocation and cleanup of an event.
+ port::Status DeallocateEvent(Event *event);
+
+ // Inserts the specified event at the end of the specified stream.
+ port::Status RecordEvent(Stream *stream, Event *event);
+
+ // Wait for the specified event at the end of the specified stream.
+ port::Status WaitForEvent(Stream *stream, Event *event);
+
+ // Requests the current status of the event from the underlying platform.
+ Event::Status PollForEventStatus(Event *event);
+
+ // Allocates stream resources on the underlying platform for subject and
+ // initializes its internals.
+ bool AllocateStream(Stream *subject);
+
+ // Deallocates stream resources on the underlying platform.
+ void DeallocateStream(Stream *subject);
+
+ // Causes dependent to not begin execution until other has finished its
+ // last-enqueued work.
+ bool CreateStreamDependency(Stream *dependent, Stream *other);
+
+ // Allocates timer resources on the underlying platform for subject and
+ // initializes its internals.
+ bool AllocateTimer(Timer *subject);
+
+ // Deallocates timer resources on the underlying platform.
+ void DeallocateTimer(Timer *subject);
+
+ // Records a start event for an interval timer.
+ bool StartTimer(Stream *stream, Timer *timer);
+
+ // Records a stop event for an interval timer.
+ bool StopTimer(Stream *stream, Timer *timer);
+
+ // Allocates a new metadata object, appropriately populated, on the heap, with
+ // ownership transfer to caller.
+ DeviceDescription *PopulateDeviceDescription() const;
+
+ // Adds a task to the port::ThreadPool work queue. These tasks must be
+ // fire-and-forget and have no external data or timing dependencies; their
+ // execution order and completion time have no guarantees.
+ // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
+ // there, temporary internal buffers are freed using this method.
+ void EnqueueOnBackgroundThread(std::function<void()> task);
+
+ // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
+ // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
+ // tracked.
+ void CreateAllocRecord(void *opaque, uint64 size);
+
+ // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
+ // pointers will not be erased (as they're not tracked, per above).
+ void EraseAllocRecord(void *opaque);
+
+ // Calls the relevant TraceListener routine to begin tracing for the specified
+ // asynchronous method.
+ template <typename TraceCallT, typename... ArgsT>
+ void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
+
+ // Reader/writer lock for class-static StreamExecutor members.
+ static mutex static_mu_;
+
+ // Reader/writer lock for mutable data structures on this StreamExecutor.
+ //
+ // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
+ // can acquire the lock on their first (mutating) call as well.
+ mutable mutex mu_;
+
+ // A mapping of pointer (to GPU memory) to string representation of the stack
+ // (of the allocating thread) at the time at which the pointer was allocated.
+ std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
+
+ // Pointer to the platform-specific-interface implementation. This is
+ // delegated to by the interface routines in pointer-to-implementation
+ // fashion.
+ std::unique_ptr<internal::StreamExecutorInterface> implementation_;
+
+ // Memoized BLAS support object -- we only want to create this once when asked
+ // for a BLAS interface.
+ std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);
+
+  // Memoized DNN support object -- we only want to create this once when asked
+  // for a DNN interface.
+ std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);
+
+  // Memoized FFT support object -- we only want to create this once when asked
+  // for an FFT interface.
+ std::unique_ptr<fft::FftSupport> fft_;
+
+ // Memoized RNG support object -- we only want to create this once when asked
+ // for an RNG interface.
+ std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);
+
+  // Slot to cache the owned DeviceDescription for the underlying device
+  // once it has been queried via GetDeviceDescription().
+ mutable std::unique_ptr<DeviceDescription> device_description_
+ GUARDED_BY(mu_);
+
+ // The kind of the underlying platform that is being targeted, as passed
+ // during construction.
+ //
+ // Immutable post-initialization.
+ PlatformKind platform_kind_;
+
+ // The device ordinal that this object was initialized with.
+ //
+ // Immutable post-initialization.
+ int device_ordinal_;
+
+ // Executor for handling host callback work that cannot be performed
+ // by a host callback thread - for example, cleanup after a host BLAS routine
+ // (which may make device API calls). This work cannot block the host
+ // callback thread, will be completed asynchronously, and should be treated
+ // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
+ // here.
+ //
+ // Immutable post-initialization. Object is thread-safe.
+ std::unique_ptr<port::ThreadPool> background_threads_;
+
+  // Counter for the current number of live streams. This is used to check
+  // for accidentally-outstanding streams at StreamExecutor teardown time, as
+  // well as to indicate leaks (via a large outstanding count being logged) in
+  // the case we can't allocate more streams.
+ std::atomic_int_fast32_t live_stream_count_;
+
+ // Only one worker thread is needed; little work will be done by the
+ // executor.
+ static const int kNumBackgroundThreads = 1;
+
+ // Indicates if StreamExecutor operation tracing should be performed.
+ bool tracing_enabled_;
+
+ // The set of TraceListeners registered for this StreamExecutor.
+ std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
+};
+
+////////////
+// Inlines
+
+template <typename T>
+inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
+ uint64 bytes = sizeof(T) * element_count;
+ void *opaque = Allocate(bytes);
+ return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+}
+
+template <typename T>
+inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
+ const string &symbol_name) {
+  // If we fail to get the symbol, opaque/bytes are left unchanged, so
+  // initialize them to nullptr/0 for consistency with DeviceMemory semantics.
+ void *opaque = nullptr;
+ size_t bytes = 0;
+ if (GetSymbol(symbol_name, &opaque, &bytes)) {
+ CHECK_EQ(bytes % sizeof(T), 0);
+ return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+ }
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if kernel using the symbol is loaded: ",
+ symbol_name));
+}
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
+ DeviceMemoryBase value)
+ : wrapped_(value), parent_(parent) {}
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
+ StreamExecutor *parent, std::initializer_list<ElemT> values)
+ : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
+ if (ptr() != nullptr) {
+ std::vector<ElemT> local(values);
+ if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
+ ptr()->size())) {
+ Reset(nullptr);
+ }
+ }
+}
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
+ parent_->Deallocate(&wrapped_);
+}
+
+template <typename ElemT>
+void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
+ parent_->Deallocate(&wrapped_);
+ wrapped_ = updated;
+}
+
+template <typename ElemT>
+void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
+ parent_->Deallocate(&wrapped_);
+ wrapped_ = DeviceMemory<ElemT>{};
+}
+
+template <typename T>
+DeviceMemory<T> StreamExecutor::AllocateZeroed() {
+ void *opaque = Allocate(sizeof(T));
+ if (opaque == nullptr) {
+ return DeviceMemory<T>{};
+ }
+
+ DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
+ bool ok = SynchronousMemZero(&result, sizeof(T));
+ if (!ok) {
+ Deallocate(&result);
+ return DeviceMemory<T>{};
+ }
+
+ return result;
+}
+
+template <typename T>
+DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count) {
+ if (element_offset + element_count > parent->ElementCount()) {
+ LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
+ << "than parent allocation size: (" << element_offset << " + "
+ << element_count << ") vs. (" << parent->ElementCount() << ")";
+ return DeviceMemory<T>{};
+ }
+
+ void *opaque = implementation_->AllocateSubBuffer(
+ parent, sizeof(T) * element_offset, sizeof(T) * element_count);
+ if (opaque == nullptr) {
+ return DeviceMemory<T>{};
+ }
+ CreateAllocRecord(opaque, sizeof(T) * element_count);
+ return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
+ true /* = is_sub_buffer */));
+}
+
+template <typename... Params, typename... Args>
+inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
+ const TypedKernel<Params...> &kernel,
+ Args... args) {
+ KernelInvocationChecker<std::tuple<Params...>,
+ std::tuple<Args...>>::CheckAllStaticAssert();
+ if (ok()) {
+ // This is the core that allows type-safe kernel launching.
+ // Since the platforms take kernel arguments as tuples of (void *, size),
+ // we pack the variadic parameters passed as ...args into the desired
+ // tuple form and pass that packed form to the StreamExecutor::Launch()
+ // implementation.
+ std::vector<KernelArg> kernel_args;
+ kernel_args.reserve(kernel.Arity());
+ kernel.PackParams(&kernel_args, args...);
+ bool ok =
+ parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
+ if (!ok) {
+ SetError();
+ LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
+ }
+ }
+ return *this;
+}
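+
+// Illustrative ThenLaunch sketch ("add_kernel" is a hypothetical
+// TypedKernel<DeviceMemory<float>, float> previously loaded via GetKernel(),
+// and "data" is a matching DeviceMemory<float>):
+//
+//   stream.ThenLaunch(ThreadDim(256), BlockDim(16), add_kernel, data, 2.0f);
+//   if (!stream.ok()) {
+//     ... the launch could not be enqueued ...
+//   }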
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_