blob: 742459c63b9541fecf4da0e8b9f6d782a8459e9c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
GPUAllocatorRetry::GPUAllocatorRetry() : env_(Env::Default()) {}
void* GPUAllocatorRetry::AllocateRaw(
std::function<void*(size_t alignment, size_t num_bytes,
bool verbose_failure)> alloc_func,
int max_millis_to_wait, size_t alignment, size_t num_bytes) {
if (num_bytes == 0) {
LOG(WARNING) << "Request to allocate 0 bytes";
return nullptr;
}
uint64 deadline_micros = env_->NowMicros() + max_millis_to_wait * 1000;
void* ptr = nullptr;
while (ptr == nullptr) {
ptr = alloc_func(alignment, num_bytes, false);
if (ptr == nullptr) {
uint64 now = env_->NowMicros();
if (now < deadline_micros) {
mutex_lock l(mu_);
WaitForMilliseconds(&l, &memory_returned_,
(deadline_micros - now) / 1000);
} else {
return alloc_func(alignment, num_bytes, true);
}
}
}
return ptr;
}
void GPUAllocatorRetry::DeallocateRaw(std::function<void(void*)> dealloc_func,
void* ptr) {
if (ptr == nullptr) {
LOG(ERROR) << "Request to free nullptr";
return;
}
dealloc_func(ptr);
{
mutex_lock l(mu_);
memory_returned_.notify_all();
}
}
} // namespace tensorflow
|