diff options
Diffstat (limited to 'tensorflow/core/framework/allocator.cc')
-rw-r--r-- | tensorflow/core/framework/allocator.cc | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 1c62d37955..888ed0c57b 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -91,6 +91,11 @@ void EnableCPUAllocatorFullStats(bool enable) { cpu_allocator_collect_full_stats = enable; } +namespace { +// A default Allocator for CPU devices. ProcessState::GetCPUAllocator() will +// return a different version that may perform better, but may also lack the +// optional stats triggered by the functions above. TODO(tucker): migrate all +// uses of cpu_allocator() except tests to use ProcessState instead. class CPUAllocator : public Allocator { public: CPUAllocator() @@ -170,14 +175,42 @@ class CPUAllocator : public Allocator { TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator); }; +class CPUAllocatorFactory : public AllocatorFactory { + public: + Allocator* CreateAllocator() override { return new CPUAllocator; } + + SubAllocator* CreateSubAllocator(int numa_node) override { + return new CPUSubAllocator(new CPUAllocator); + } + + private: + class CPUSubAllocator : public SubAllocator { + public: + explicit CPUSubAllocator(CPUAllocator* cpu_allocator) + : cpu_allocator_(cpu_allocator) {} + + void* Alloc(size_t alignment, size_t num_bytes) override { + return cpu_allocator_->AllocateRaw(alignment, num_bytes); + } + + void Free(void* ptr, size_t num_bytes) override { + cpu_allocator_->DeallocateRaw(ptr); + } + + private: + CPUAllocator* cpu_allocator_; + }; +}; + +REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocatorFactory); +} // namespace + Allocator* cpu_allocator() { - static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator(); + static Allocator* cpu_alloc = + AllocatorFactoryRegistry::singleton()->GetAllocator(); if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { cpu_alloc = new TrackingAllocator(cpu_alloc, true); } return cpu_alloc; } - -REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocator); - } // namespace tensorflow |