diff options
Diffstat (limited to 'tensorflow/core/framework/allocator_registry.cc')
-rw-r--r-- | tensorflow/core/framework/allocator_registry.cc | 120 |
1 files changed, 85 insertions, 35 deletions
diff --git a/tensorflow/core/framework/allocator_registry.cc b/tensorflow/core/framework/allocator_registry.cc index 486be39ae3..099c4bacc8 100644 --- a/tensorflow/core/framework/allocator_registry.cc +++ b/tensorflow/core/framework/allocator_registry.cc @@ -21,60 +21,110 @@ limitations under the License. namespace tensorflow { // static -AllocatorRegistry* AllocatorRegistry::Global() { - static AllocatorRegistry* global_allocator_registry = new AllocatorRegistry; - return global_allocator_registry; +AllocatorFactoryRegistry* AllocatorFactoryRegistry::singleton() { + static AllocatorFactoryRegistry* singleton = new AllocatorFactoryRegistry; + return singleton; } -Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name, - int priority) { - for (auto entry : allocators_) { +const AllocatorFactoryRegistry::FactoryEntry* +AllocatorFactoryRegistry::FindEntry(const string& name, int priority) const { + for (auto& entry : factories_) { if (!name.compare(entry.name) && priority == entry.priority) { - return entry.allocator; + return &entry; } } return nullptr; } -void AllocatorRegistry::Register(const string& name, int priority, - Allocator* allocator) { +void AllocatorFactoryRegistry::Register(const char* source_file, + int source_line, const string& name, + int priority, + AllocatorFactory* factory) { + mutex_lock l(mu_); + CHECK(!first_alloc_made_) << "Attempt to register an AllocatorFactory " + << "after call to GetAllocator()"; CHECK(!name.empty()) << "Need a valid name for Allocator"; CHECK_GE(priority, 0) << "Priority needs to be non-negative"; - Allocator* existing = GetRegisteredAllocator(name, priority); + const FactoryEntry* existing = FindEntry(name, priority); if (existing != nullptr) { - // A duplicate is if the registration name and priority match - // but the Allocator::Name()'s don't match. - CHECK_EQ(existing->Name(), allocator->Name()) - << "Allocator with name: [" << name << "], type [" << existing->Name() - << "], priority: [" << priority - << "] already registered. Choose a different name to register " - << "an allocator of type " << allocator->Name(); - - // The allocator names match, so we can just return. - // It should be safe to delete the allocator since the caller - // gives up ownership of it. - delete allocator; - return; + // Duplicate registration is a hard failure. + LOG(FATAL) << "New registration for AllocatorFactory with name=" << name + << " priority=" << priority << " at location " << source_file + << ":" << source_line + << " conflicts with previous registration at location " + << existing->source_file << ":" << existing->source_line; } - AllocatorRegistryEntry tmp_entry; - tmp_entry.name = name; - tmp_entry.priority = priority; - tmp_entry.allocator = allocator; + FactoryEntry entry; + entry.source_file = source_file; + entry.source_line = source_line; + entry.name = name; + entry.priority = priority; + entry.factory.reset(factory); + factories_.push_back(std::move(entry)); +} - allocators_.push_back(tmp_entry); - int high_pri = -1; - for (auto entry : allocators_) { - if (high_pri < entry.priority) { - m_curr_allocator_ = entry.allocator; - high_pri = entry.priority; +Allocator* AllocatorFactoryRegistry::GetAllocator() { + mutex_lock l(mu_); + first_alloc_made_ = true; + FactoryEntry* best_entry = nullptr; + for (auto& entry : factories_) { + if (best_entry == nullptr) { + best_entry = &entry; + } else if (entry.priority > best_entry->priority) { + best_entry = &entry; } } + if (best_entry) { + if (!best_entry->allocator) { + best_entry->allocator.reset(best_entry->factory->CreateAllocator()); + } + return best_entry->allocator.get(); + } else { + LOG(FATAL) << "No registered CPU AllocatorFactory"; + return nullptr; + } } -Allocator* AllocatorRegistry::GetAllocator() { - return CHECK_NOTNULL(m_curr_allocator_); +SubAllocator* AllocatorFactoryRegistry::GetSubAllocator(int numa_node) { + mutex_lock l(mu_); + first_alloc_made_ = true; + FactoryEntry* best_entry = nullptr; + for (auto& entry : factories_) { + if (best_entry == nullptr) { + best_entry = &entry; + } else if (best_entry->factory->NumaEnabled()) { + if (entry.factory->NumaEnabled() && + (entry.priority > best_entry->priority)) { + best_entry = &entry; + } + } else { + DCHECK(!best_entry->factory->NumaEnabled()); + if (entry.factory->NumaEnabled() || + (entry.priority > best_entry->priority)) { + best_entry = &entry; + } + } + } + if (best_entry) { + int index = 0; + if (numa_node != port::kNUMANoAffinity) { + CHECK_LE(numa_node, port::NUMANumNodes()); + index = 1 + numa_node; + } + if (best_entry->sub_allocators.size() < (index + 1)) { + best_entry->sub_allocators.resize(index + 1); + } + if (!best_entry->sub_allocators[index].get()) { + best_entry->sub_allocators[index].reset( + best_entry->factory->CreateSubAllocator(numa_node)); + } + return best_entry->sub_allocators[index].get(); + } else { + LOG(FATAL) << "No registered CPU AllocatorFactory"; + return nullptr; + } } } // namespace tensorflow |