diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/stream_executor/temporary_memory_manager.h |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/stream_executor/temporary_memory_manager.h')
-rw-r--r-- | tensorflow/stream_executor/temporary_memory_manager.h | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/temporary_memory_manager.h b/tensorflow/stream_executor/temporary_memory_manager.h new file mode 100644 index 0000000000..847f0f2182 --- /dev/null +++ b/tensorflow/stream_executor/temporary_memory_manager.h @@ -0,0 +1,138 @@ +// The temporary-memory-manager is a helper class for a Stream to keep track of +// temporary allocations. These allocations defer their deallocation to the next +// Stream::BlockHostUntilDone call for efficiency purposes (as deallocation +// itself generally forces synchronization to occur). + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ + +#include <map> +#include <memory> + +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/platform/mutex.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" +#include "tensorflow/stream_executor/temporary_device_memory.h" + +namespace perftools { +namespace gputools { +namespace internal { + +// Record used inside the TemporaryMemoryManager as metadata for a given device +// memory region. +struct TemporaryMemoryRecord { + // What "generation" this record was allocated in. + // + // Currently the generation counter is bumped for every allocation, but this + // could be made coarser if necessary. + // + // The generation distinguishes this record from an earlier temporary that was + // deallocated and later reallocated at the same device memory address. + uint64 allocation_generation; + + // Notes whether the temporary memory has been marked as finalized, such that + // we can release the DeviceMemory associated with this record at + // synchronization time. + // + // NOTE(review): presumably set through TemporaryMemoryManager::MarkFinalized + // -- confirm against the source file. + bool finalized; +}; + +// Manages temporary memories associated with a stream -- keeps records of +// outstanding temporaries and their state, and can deallocate them +// appropriately at points in the Stream lifecycle (e.g. BlockHostUntilDone, +// destruction). 
+class TemporaryMemoryManager { + public: + explicit TemporaryMemoryManager(Stream* stream) : stream_(stream) {} + + // Allocates a temporary array that is then managed by this object. + template <typename T> + port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> AllocateArray( + uint64 element_count); + + // Forces deallocation of all managed temporary memory regions. + // + // Called, for example, when the Stream owning this temporary memory manager + // is destroyed. + // + // Note: These calls to Deallocate will likely force synchronization. + void ForceDeallocateAll(); + + // Marks the given memory region as finalized. + // + // If must_exist is set, this will check-fail if the temporary memory record + // is not found. + void MarkFinalized(const DeviceMemoryBase& device_memory, uint64 generation, + bool must_exist); + + // Deallocates temporary memories that have been finalized. + // + // Note: These calls to Deallocate will likely force synchronization, so it is + // meant to be called before a "BlockHostUntilDone" is about to be performed. + void DeallocateFinalizedTemporaries(); + + // Returns whether the provided device_memory is finalized. + // + // In the vacuous case where the device memory doesn't appear in the temporary + // memory records, it is either not a temporary at all, or has already been + // deallocated, and thus returns true. + bool IsFinalized(const DeviceMemoryBase& device_memory, + uint64 allocation_generation) const; + + // Returns whether the manager has a live allocation record for the given + // device memory pointer with the given generation counter. + // + // Note: this is a polling call -- there is no guarantee that the region is + // still allocated once the call has completed. + bool HasAllocated(const DeviceMemoryBase& device_memory, + uint64 generation) const; + + private: + // Allocates an array without type parameterization, so that the + // implementation can live in the source file. 
Without this base allocation + // method, we incur a circular dependency between the StreamExecutor + // definition and this class' definition. + port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>> AllocateArrayBase( + uint64 element_count, uint64 element_size); + + // Mutex to guard temporary record state. + mutable mutex mutex_; + + // Mapping from device memory to the current (live) temporary memory record. + // + // If a device memory is not in this mapping, it is not a temporary currently + // allocated and owned by this temporary memory manager. + std::map<DeviceMemoryBase, TemporaryMemoryRecord> records_ GUARDED_BY(mutex_); + + // Allocation generation -- we bump this counter to distinguish temporary + // memory handles that have been deallocated and later reallocated at the same + // device memory address. + uint64 generation_ GUARDED_BY(mutex_); + + // The stream (parent object) for this temporary memory manager -- allocations + // are performed through this stream handle. + Stream* stream_; + + SE_DISALLOW_COPY_AND_ASSIGN(TemporaryMemoryManager); +}; + +//////////// +// Inlines + +template <typename T> +port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> +TemporaryMemoryManager::AllocateArray(uint64 element_count) { + port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>> temporary_memory = + AllocateArrayBase(element_count, sizeof(T)); + if (!temporary_memory.ok()) { + return temporary_memory.status(); + } + + return std::unique_ptr<TemporaryDeviceMemory<T>>( + reinterpret_cast<TemporaryDeviceMemory<T>*>( + temporary_memory.ConsumeValueOrDie().release())); +} + +} // namespace internal +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ |