aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/allocator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/allocator.cc')
-rw-r--r--tensorflow/core/framework/allocator.cc41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 1c62d37955..888ed0c57b 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -91,6 +91,11 @@ void EnableCPUAllocatorFullStats(bool enable) {
cpu_allocator_collect_full_stats = enable;
}
+namespace {
+// A default Allocator for CPU devices. ProcessState::GetCPUAllocator() will
+// return a different version that may perform better, but may also lack the
+// optional stats triggered by the functions above. TODO(tucker): migrate all
+// uses of cpu_allocator() except tests to use ProcessState instead.
class CPUAllocator : public Allocator {
public:
CPUAllocator()
@@ -170,14 +175,42 @@ class CPUAllocator : public Allocator {
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};
+class CPUAllocatorFactory : public AllocatorFactory {
+ public:
+ Allocator* CreateAllocator() override { return new CPUAllocator; }
+
+ SubAllocator* CreateSubAllocator(int numa_node) override {
+ return new CPUSubAllocator(new CPUAllocator);
+ }
+
+ private:
+ class CPUSubAllocator : public SubAllocator {
+ public:
+ explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
+ : cpu_allocator_(cpu_allocator) {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ return cpu_allocator_->AllocateRaw(alignment, num_bytes);
+ }
+
+ void Free(void* ptr, size_t num_bytes) override {
+ cpu_allocator_->DeallocateRaw(ptr);
+ }
+
+ private:
+ CPUAllocator* cpu_allocator_;
+ };
+};
+
+REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocatorFactory);
+} // namespace
+
Allocator* cpu_allocator() {
- static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator();
+ static Allocator* cpu_alloc =
+ AllocatorFactoryRegistry::singleton()->GetAllocator();
if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) {
cpu_alloc = new TrackingAllocator(cpu_alloc, true);
}
return cpu_alloc;
}
-
-REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocator);
-
} // namespace tensorflow