aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/temporary_device_memory.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/temporary_device_memory.h')
-rw-r--r--tensorflow/stream_executor/temporary_device_memory.h123
1 files changed, 123 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/temporary_device_memory.h b/tensorflow/stream_executor/temporary_device_memory.h
new file mode 100644
index 0000000000..4e7c63056b
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_device_memory.h
@@ -0,0 +1,123 @@
+// Temporary memories are used to allocate scratch space required by an
+// operation about to be enqueued onto a stream.
+//
+// std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory =
+// stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie();
+// // ... enqueue stuff onto the stream using the temporary memory ...
+// // Note that the memory is accessible via
+// // temporary_memory->device_memory() and similar.
+//
+// // Finalize the temporary memory. The underlying device memory may
+// // be released any time after this program point, as another thread may
+// // call Stream::BlockHostUntilDone, causing synchronization. This
+// // finalization also happens automatically for the user if the unique_ptr
+// // goes out of scope.
+// temporary_memory.Finalize();
+//
+// WARNING: do NOT hold onto the device memory associated with temporary_memory
+// after finalization. If temporary_memory->device_memory() is used after the
+// temporary memory is finalized, it will cause a DCHECK failure.
+//
+// Note that standard usage takes advantage of the type-safe wrapper,
+// TemporaryDeviceMemory<T>, defined below.
+//
+// Also see tests for executable sample usage.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
+
+#include "tensorflow/stream_executor/device_memory.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+namespace internal {
+class TemporaryMemoryManager;
+}
+
+// Untyped base class (analogous to a void*) for temporary device memory
+// allocations associated with a stream.
+class TemporaryDeviceMemoryBase {
+ public:
+ // Marks the temporary memory as finalized if it is not already marked as
+ // such.
+ ~TemporaryDeviceMemoryBase();
+
+ // Precondition: !IsFinalized()
+ DeviceMemoryBase* mutable_device_memory();
+
+ // Precondition: !IsFinalized()
+ const DeviceMemoryBase& device_memory() const;
+
+ // "Finalizes" this temporary memory, making it acceptable to release at the
+ // next stream synchronization point -- the device memory can be reclaimed at
+ // any time after the temporary memory is marked as finalized (e.g. if a
+ // separate thread is calls Stream::BlockHostUntilDone). This may only be
+ // called once -- see the precondition below.
+ //
+ // Precondition: !IsFinalized()
+ void Finalize();
+
+ // Returns true iff the temporary memory is finalized (that is, the user is
+ // done referring to the temporary device memory, and thus it can be released
+ // at the next stream synchronization point).
+ bool IsFinalized() const;
+
+ // Returns true iff the temporary memory is still allocated.
+ //
+ // Note: this is a polling call, no guarantee is made that the temporary
+ // memory is still allocated after the call has completed.
+ bool IsAllocated() const;
+
+ private:
+ friend class internal::TemporaryMemoryManager;
+ friend class TemporaryDeviceMemoryTest;
+
+ // Note: construction DCHECKs that the memory is known-allocated in the
+ // stream's temporary-allocation-manager.
+ TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory,
+ uint64 allocation_generation);
+
+ // The device memory region that has allocated.
+ DeviceMemoryBase device_memory_;
+
+ // The generation counter value for the temporary memory record in the
+ // temporary memory manager.
+ uint64 allocation_generation_;
+
+ // The stream that this temporary memory was allocated for.
+ Stream* parent_;
+};
+
+// Type-safe wrapper around the base type (which is analogous to a void*).
+template <typename T>
+class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase {
+ public:
+ // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory.
+ DeviceMemory<T>* mutable_device_memory() {
+ StaticSlicingAssertionDummy();
+ return reinterpret_cast<DeviceMemory<T>*>(
+ TemporaryDeviceMemoryBase::mutable_device_memory());
+ }
+
+ // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory.
+ const DeviceMemory<T>& device_memory() const {
+ StaticSlicingAssertionDummy();
+ return reinterpret_cast<const DeviceMemory<T>&>(
+ TemporaryDeviceMemoryBase::device_memory());
+ }
+
+ private:
+ static void StaticSlicingAssertionDummy() {
+ static_assert(
+ sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase),
+ "derived class is simply a wrapper, no members may be added due to "
+ "slicing");
+ }
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_