aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/temporary_memory_manager.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/temporary_memory_manager.cc')
-rw-r--r--tensorflow/stream_executor/temporary_memory_manager.cc113
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/temporary_memory_manager.cc b/tensorflow/stream_executor/temporary_memory_manager.cc
new file mode 100644
index 0000000000..0352aa4b2b
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_memory_manager.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/stream_executor/temporary_memory_manager.h"
+
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+// Releases the device memory behind every record currently tracked by this
+// manager, regardless of whether the record has been marked finalized. The
+// records themselves are left in records_ by this function.
+void TemporaryMemoryManager::ForceDeallocateAll() {
+  mutex_lock lock(mutex_);
+  VLOG(1) << "force-deallocating " << records_.size() << " remaining records";
+  for (const auto& record : records_) {
+    // The key is copied so that Deallocate() can be given a mutable pointer.
+    DeviceMemoryBase memory = record.first;
+    stream_->parent()->Deallocate(&memory);
+  }
+}
+
+// Flags the record for `device_memory` as finalized so that a later call to
+// DeallocateFinalizedTemporaries() will release it.
+//
+// If no record exists: aborts via LOG(FATAL) when `must_exist` is true,
+// otherwise does nothing. Note that `generation` is accepted but not
+// consulted here; the record is looked up by `device_memory` alone.
+void TemporaryMemoryManager::MarkFinalized(
+    const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) {
+  mutex_lock lock(mutex_);
+  auto found = records_.find(device_memory);
+  if (found != records_.end()) {
+    found->second.finalized = true;
+    return;
+  }
+
+  // No record for this memory: only an error if the caller required one.
+  if (must_exist) {
+    LOG(FATAL) << "attempted to mark finalization for temporary "
+                  "memory that does not exist";
+  }
+}
+
+// Walks all records, deallocating and erasing each one that has been marked
+// finalized. Records that are not finalized are left untouched.
+void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
+  mutex_lock lock(mutex_);
+  int freed = 0;
+  auto cursor = records_.begin();
+  while (cursor != records_.end()) {
+    if (!cursor->second.finalized) {
+      ++cursor;
+      continue;
+    }
+    // The key is copied so that Deallocate() can be given a mutable pointer.
+    DeviceMemoryBase memory = cursor->first;
+    stream_->parent()->Deallocate(&memory);
+    ++freed;
+    // erase() hands back the next valid iterator, so iteration stays safe.
+    cursor = records_.erase(cursor);
+  }
+  VLOG(1) << "deallocated " << freed << " finalized temporaries";
+}
+
+// Reports whether the allocation identified by `device_memory` and
+// `allocation_generation` is finalized.
+//
+// Returns the record's finalized flag when a record exists for
+// `device_memory` with a matching generation. When there is no record, or the
+// recorded generation differs, the answer is vacuously true.
+bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
+                                         uint64 allocation_generation) const {
+  mutex_lock lock(mutex_);
+  const auto found = records_.find(device_memory);
+  const bool same_allocation =
+      found != records_.end() &&
+      found->second.allocation_generation == allocation_generation;
+  // Short-circuit keeps us from touching `found` unless it is a valid match.
+  return !same_allocation || found->second.finalized;
+}
+
+// Returns true only if a record exists for `device_memory` and that record's
+// allocation generation equals `generation`.
+bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory,
+                                          uint64 generation) const {
+  mutex_lock lock(mutex_);
+  const auto found = records_.find(device_memory);
+  return found != records_.end() &&
+         found->second.allocation_generation == generation;
+}
+
+// Allocates `element_count * element_size` bytes of device memory through the
+// stream's parent executor, registers a (not yet finalized) record for it
+// under a fresh generation number, and wraps it in a
+// TemporaryDeviceMemoryBase.
+//
+// Returns a RESOURCE_EXHAUSTED status if the requested byte count does not
+// fit in a uint64 or if the underlying allocation yields null memory.
+port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>>
+TemporaryMemoryManager::AllocateArrayBase(uint64 element_count,
+                                          uint64 element_size) {
+  uint64 byte_size = element_count * element_size;
+  // Unsigned multiplication wraps silently; if dividing back does not recover
+  // element_count the product overflowed, and allocating the wrapped size
+  // would hand the caller a buffer smaller than it asked for.
+  if (element_size != 0 && byte_size / element_size != element_count) {
+    return port::Status(
+        port::error::RESOURCE_EXHAUSTED,
+        port::StrCat("could not allocate temporary memory: ", element_count,
+                     " elements of ", element_size,
+                     " bytes each overflows the byte count"));
+  }
+
+  DeviceMemoryBase device_memory =
+      stream_->parent()->AllocateArray<uint8>(byte_size);
+  if (device_memory == nullptr) {
+    return port::Status(port::error::RESOURCE_EXHAUSTED,
+                        port::StrCat("could not allocate temporary memory of ",
+                                     byte_size, " bytes"));
+  }
+
+  uint64 generation;
+
+  // Add the record before instantiating the device memory instance so we can
+  // check the allocation invariant at TemporaryDeviceMemory construction time.
+  {
+    mutex_lock lock(mutex_);
+    generation = ++generation_;
+    DCHECK(records_.find(device_memory) == records_.end());
+    records_[device_memory] = {generation,
+                               /*finalized=*/false};
+  }
+
+  VLOG(1) << port::Printf(
+      "stream %p allocated temporary device memory at %p (size %llu) in "
+      "generation %llu",
+      stream_, device_memory.opaque(), byte_size, generation);
+  std::unique_ptr<TemporaryDeviceMemoryBase> result(
+      new TemporaryDeviceMemoryBase(stream_, device_memory, generation));
+  // `result` is move-only and is converted to the StatusOr return type, so it
+  // is moved explicitly.
+  return std::move(result);
+}
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools