// Temporary memories are used to allocate scratch space required by an // operation about to be enqueued onto a stream. // // std::unique_ptr> temporary_memory = // stream.AllocateTemporaryArray(1024).ConsumeValueOrDie(); // // ... enqueue stuff onto the stream using the temporary memory ... // // Note that the memory is accessible via // // temporary_memory->device_memory() and similar. // // // Finalize the temporary memory. The underlying device memory may // // be released any time after this program point, as another thread may // // call Stream::BlockHostUntilDone, causing synchronization. This // // finalization also happens automatically for the user if the unique_ptr // // goes out of scope. // temporary_memory.Finalize(); // // WARNING: do NOT hold onto the device memory associated with temporary_memory // after finalization. If temporary_memory->device_memory() is used after the // temporary memory is finalized, it will cause a DCHECK failure. // // Note that standard usage takes advantage of the type-safe wrapper, // TemporaryDeviceMemory, defined below. // // Also see tests for executable sample usage. #ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ #define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ #include "tensorflow/stream_executor/device_memory.h" namespace perftools { namespace gputools { class Stream; namespace internal { class TemporaryMemoryManager; } // Untyped base class (analogous to a void*) for temporary device memory // allocations associated with a stream. class TemporaryDeviceMemoryBase { public: // Marks the temporary memory as finalized if it is not already marked as // such. ~TemporaryDeviceMemoryBase(); // Precondition: !IsFinalized() DeviceMemoryBase* mutable_device_memory(); // Precondition: !IsFinalized() const DeviceMemoryBase& device_memory() const; // "Finalizes" this temporary memory, making it acceptable to release at the // next stream synchronization point -- the device memory can be reclaimed at // any time after the temporary memory is marked as finalized (e.g. if a // separate thread is calls Stream::BlockHostUntilDone). This may only be // called once -- see the precondition below. // // Precondition: !IsFinalized() void Finalize(); // Returns true iff the temporary memory is finalized (that is, the user is // done referring to the temporary device memory, and thus it can be released // at the next stream synchronization point). bool IsFinalized() const; // Returns true iff the temporary memory is still allocated. // // Note: this is a polling call, no guarantee is made that the temporary // memory is still allocated after the call has completed. bool IsAllocated() const; private: friend class internal::TemporaryMemoryManager; friend class TemporaryDeviceMemoryTest; // Note: construction DCHECKs that the memory is known-allocated in the // stream's temporary-allocation-manager. TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory, uint64 allocation_generation); // The device memory region that has allocated. DeviceMemoryBase device_memory_; // The generation counter value for the temporary memory record in the // temporary memory manager. uint64 allocation_generation_; // The stream that this temporary memory was allocated for. Stream* parent_; }; // Type-safe wrapper around the base type (which is analogous to a void*). template class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { public: // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory. DeviceMemory* mutable_device_memory() { StaticSlicingAssertionDummy(); return reinterpret_cast*>( TemporaryDeviceMemoryBase::mutable_device_memory()); } // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory. const DeviceMemory& device_memory() const { StaticSlicingAssertionDummy(); return reinterpret_cast&>( TemporaryDeviceMemoryBase::device_memory()); } private: static void StaticSlicingAssertionDummy() { static_assert( sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase), "derived class is simply a wrapper, no members may be added due to " "slicing"); } }; } // namespace gputools } // namespace perftools #endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_