tensorflow/core/framework/tracking_allocator.cc
#include "tensorflow/core/framework/tracking_allocator.h"

#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

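// TrackingAllocator wraps another Allocator and records the bytes handed
// out through it. Lifetime: the tracker is created with one reference,
// which the creator releases via GetSizesAndUnRef(); every outstanding
// allocation holds an additional reference, so the tracker stays alive
// until all of its allocations have been deallocated.
//
// Usage sketch (illustrative; cpu_allocator() from
// tensorflow/core/framework/allocator.h is one possible wrapped allocator):
//
//   TrackingAllocator* tracker = new TrackingAllocator(cpu_allocator());
//   void* p = tracker->AllocateRaw(32 /* alignment */, 1024 /* num_bytes */);
//   // ... use p ...
//   tracker->DeallocateRaw(p);
//   std::pair<size_t, size_t> sizes = tracker->GetSizesAndUnRef();
//   // sizes.first is the total bytes allocated; sizes.second is the high
//   // watermark. The tracker must not be used after GetSizesAndUnRef().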
TrackingAllocator::TrackingAllocator(Allocator* allocator)
    : allocator_(allocator),
      ref_(1),
      allocated_(0),
      high_watermark_(0),
      total_bytes_(0) {}

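// Forwards the allocation to the wrapped allocator and, on success, records
// the size and takes a reference on behalf of the outstanding allocation.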
void* TrackingAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  void* ptr = allocator_->AllocateRaw(alignment, num_bytes);
  // If memory is exhausted, AllocateRaw returns nullptr, and we pass this
  // through to the caller.
  if (nullptr == ptr) {
    return ptr;
  }
  if (allocator_->TracksAllocationSizes()) {
    size_t allocated_bytes = allocator_->AllocatedSize(ptr);
    {
      mutex_lock lock(mu_);
      allocated_ += allocated_bytes;
      high_watermark_ = std::max(high_watermark_, allocated_);
      total_bytes_ += allocated_bytes;
      ++ref_;
    }
  } else {
    mutex_lock lock(mu_);
    total_bytes_ += num_bytes;
    ++ref_;
  }
  return ptr;
}

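// Forwards the deallocation to the wrapped allocator and releases the
// reference taken in AllocateRaw; deletes the tracker if that was the last
// reference.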
void TrackingAllocator::DeallocateRaw(void* ptr) {
  // Freeing a null pointer is a no-op.
  if (nullptr == ptr) {
    return;
  }
  bool should_delete;
  // Fetch the following values outside the lock in case the call to
  // AllocatedSize is slow.
  bool tracks_allocation_sizes = allocator_->TracksAllocationSizes();
  size_t allocated_bytes = 0;
  if (tracks_allocation_sizes) {
    allocated_bytes = allocator_->AllocatedSize(ptr);
  }
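  // Copy allocator_ to a local before dropping this allocation's reference:
  // once the lock is released, another thread may release the last
  // reference and delete this object, so no member may be touched below.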
  Allocator* allocator = allocator_;
  {
    mutex_lock lock(mu_);
    if (tracks_allocation_sizes) {
      CHECK_GE(allocated_, allocated_bytes);
      allocated_ -= allocated_bytes;
    }
    should_delete = UnRef();
  }
  allocator->DeallocateRaw(ptr);
  if (should_delete) {
    delete this;
  }
}

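// The following three methods are pass-throughs to the wrapped allocator.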
bool TrackingAllocator::TracksAllocationSizes() {
  return allocator_->TracksAllocationSizes();
}

size_t TrackingAllocator::RequestedSize(void* ptr) {
  return allocator_->RequestedSize(ptr);
}

size_t TrackingAllocator::AllocatedSize(void* ptr) {
  return allocator_->AllocatedSize(ptr);
}

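// Returns a pair of (total bytes allocated over the tracker's lifetime,
// high watermark of simultaneously live bytes) and releases the creator's
// reference. If the wrapped allocator does not track allocation sizes, the
// total reflects requested bytes and the high watermark stays 0. The
// tracker may delete itself before returning, so it must not be used after
// this call.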
std::pair<size_t, size_t> TrackingAllocator::GetSizesAndUnRef() {
  size_t high_watermark;
  size_t total_bytes;
  bool should_delete;
  {
    mutex_lock lock(mu_);
    high_watermark = high_watermark_;
    total_bytes = total_bytes_;
    should_delete = UnRef();
  }
  if (should_delete) {
    delete this;
  }
  return std::make_pair(total_bytes, high_watermark);
}

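// Releases one reference; mu_ must be held by the caller. Returns true when
// the last reference has been released, in which case the caller must
// delete this object.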
bool TrackingAllocator::UnRef() {
  CHECK_GE(ref_, 1);
  --ref_;
  return (ref_ == 0);
}

}  // end namespace tensorflow