diff options
Diffstat (limited to 'tensorflow/stream_executor/temporary_device_memory.h')
-rw-r--r-- | tensorflow/stream_executor/temporary_device_memory.h | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/temporary_device_memory.h b/tensorflow/stream_executor/temporary_device_memory.h new file mode 100644 index 0000000000..4e7c63056b --- /dev/null +++ b/tensorflow/stream_executor/temporary_device_memory.h @@ -0,0 +1,123 @@ +// Temporary memories are used to allocate scratch space required by an +// operation about to be enqueued onto a stream. +// +// std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory = +// stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie(); +// // ... enqueue stuff onto the stream using the temporary memory ... +// // Note that the memory is accessible via +// // temporary_memory->device_memory() and similar. +// +// // Finalize the temporary memory. The underlying device memory may +// // be released any time after this program point, as another thread may +// // call Stream::BlockHostUntilDone, causing synchronization. This +// // finalization also happens automatically for the user if the unique_ptr +// // goes out of scope. +// temporary_memory->Finalize(); +// +// WARNING: do NOT hold onto the device memory associated with temporary_memory +// after finalization. If temporary_memory->device_memory() is used after the +// temporary memory is finalized, it will cause a DCHECK failure. +// +// Note that standard usage takes advantage of the type-safe wrapper, +// TemporaryDeviceMemory<T>, defined below. +// +// Also see tests for executable sample usage. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ + +#include "tensorflow/stream_executor/device_memory.h" + +namespace perftools { +namespace gputools { + +class Stream; +namespace internal { +class TemporaryMemoryManager; +} + +// Untyped base class (analogous to a void*) for temporary device memory +// allocations associated with a stream. 
+class TemporaryDeviceMemoryBase { + public: + // Marks the temporary memory as finalized if it is not already marked as + // such. + ~TemporaryDeviceMemoryBase(); + + // Precondition: !IsFinalized() + DeviceMemoryBase* mutable_device_memory(); + + // Precondition: !IsFinalized() + const DeviceMemoryBase& device_memory() const; + + // "Finalizes" this temporary memory, making it acceptable to release at the + // next stream synchronization point -- the device memory can be reclaimed at + // any time after the temporary memory is marked as finalized (e.g. if a + // separate thread is calls Stream::BlockHostUntilDone). This may only be + // called once -- see the precondition below. + // + // Precondition: !IsFinalized() + void Finalize(); + + // Returns true iff the temporary memory is finalized (that is, the user is + // done referring to the temporary device memory, and thus it can be released + // at the next stream synchronization point). + bool IsFinalized() const; + + // Returns true iff the temporary memory is still allocated. + // + // Note: this is a polling call, no guarantee is made that the temporary + // memory is still allocated after the call has completed. + bool IsAllocated() const; + + private: + friend class internal::TemporaryMemoryManager; + friend class TemporaryDeviceMemoryTest; + + // Note: construction DCHECKs that the memory is known-allocated in the + // stream's temporary-allocation-manager. + TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory, + uint64 allocation_generation); + + // The device memory region that has allocated. + DeviceMemoryBase device_memory_; + + // The generation counter value for the temporary memory record in the + // temporary memory manager. + uint64 allocation_generation_; + + // The stream that this temporary memory was allocated for. + Stream* parent_; +}; + +// Type-safe wrapper around the base type (which is analogous to a void*). 
+template <typename T> +class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { + public: + // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory. + DeviceMemory<T>* mutable_device_memory() { + StaticSlicingAssertionDummy(); + return reinterpret_cast<DeviceMemory<T>*>( + TemporaryDeviceMemoryBase::mutable_device_memory()); + } + + // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory. + const DeviceMemory<T>& device_memory() const { + StaticSlicingAssertionDummy(); + return reinterpret_cast<const DeviceMemory<T>&>( + TemporaryDeviceMemoryBase::device_memory()); + } + + private: + static void StaticSlicingAssertionDummy() { + static_assert( + sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase), + "derived class is simply a wrapper, no members may be added due to " + "slicing"); + } +}; + +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ |