aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/allocator_registry.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/allocator_registry.cc')
-rw-r--r--tensorflow/core/framework/allocator_registry.cc120
1 files changed, 85 insertions, 35 deletions
diff --git a/tensorflow/core/framework/allocator_registry.cc b/tensorflow/core/framework/allocator_registry.cc
index 486be39ae3..099c4bacc8 100644
--- a/tensorflow/core/framework/allocator_registry.cc
+++ b/tensorflow/core/framework/allocator_registry.cc
@@ -21,60 +21,110 @@ limitations under the License.
namespace tensorflow {
// static
-AllocatorRegistry* AllocatorRegistry::Global() {
- static AllocatorRegistry* global_allocator_registry = new AllocatorRegistry;
- return global_allocator_registry;
+AllocatorFactoryRegistry* AllocatorFactoryRegistry::singleton() {
+ static AllocatorFactoryRegistry* singleton = new AllocatorFactoryRegistry;
+ return singleton;
}
-Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name,
- int priority) {
- for (auto entry : allocators_) {
+const AllocatorFactoryRegistry::FactoryEntry*
+AllocatorFactoryRegistry::FindEntry(const string& name, int priority) const {
+ for (auto& entry : factories_) {
if (!name.compare(entry.name) && priority == entry.priority) {
- return entry.allocator;
+ return &entry;
}
}
return nullptr;
}
-void AllocatorRegistry::Register(const string& name, int priority,
- Allocator* allocator) {
+void AllocatorFactoryRegistry::Register(const char* source_file,
+ int source_line, const string& name,
+ int priority,
+ AllocatorFactory* factory) {
+ mutex_lock l(mu_);
+ CHECK(!first_alloc_made_) << "Attempt to register an AllocatorFactory "
+ << "after call to GetAllocator()";
CHECK(!name.empty()) << "Need a valid name for Allocator";
CHECK_GE(priority, 0) << "Priority needs to be non-negative";
- Allocator* existing = GetRegisteredAllocator(name, priority);
+ const FactoryEntry* existing = FindEntry(name, priority);
if (existing != nullptr) {
- // A duplicate is if the registration name and priority match
- // but the Allocator::Name()'s don't match.
- CHECK_EQ(existing->Name(), allocator->Name())
- << "Allocator with name: [" << name << "], type [" << existing->Name()
- << "], priority: [" << priority
- << "] already registered. Choose a different name to register "
- << "an allocator of type " << allocator->Name();
-
- // The allocator names match, so we can just return.
- // It should be safe to delete the allocator since the caller
- // gives up ownership of it.
- delete allocator;
- return;
+ // Duplicate registration is a hard failure.
+ LOG(FATAL) << "New registration for AllocatorFactory with name=" << name
+ << " priority=" << priority << " at location " << source_file
+ << ":" << source_line
+ << " conflicts with previous registration at location "
+ << existing->source_file << ":" << existing->source_line;
}
- AllocatorRegistryEntry tmp_entry;
- tmp_entry.name = name;
- tmp_entry.priority = priority;
- tmp_entry.allocator = allocator;
+ FactoryEntry entry;
+ entry.source_file = source_file;
+ entry.source_line = source_line;
+ entry.name = name;
+ entry.priority = priority;
+ entry.factory.reset(factory);
+ factories_.push_back(std::move(entry));
+}
- allocators_.push_back(tmp_entry);
- int high_pri = -1;
- for (auto entry : allocators_) {
- if (high_pri < entry.priority) {
- m_curr_allocator_ = entry.allocator;
- high_pri = entry.priority;
+Allocator* AllocatorFactoryRegistry::GetAllocator() {
+ mutex_lock l(mu_);
+ first_alloc_made_ = true;
+ FactoryEntry* best_entry = nullptr;
+ for (auto& entry : factories_) {
+ if (best_entry == nullptr) {
+ best_entry = &entry;
+ } else if (entry.priority > best_entry->priority) {
+ best_entry = &entry;
}
}
+ if (best_entry) {
+ if (!best_entry->allocator) {
+ best_entry->allocator.reset(best_entry->factory->CreateAllocator());
+ }
+ return best_entry->allocator.get();
+ } else {
+ LOG(FATAL) << "No registered CPU AllocatorFactory";
+ return nullptr;
+ }
}
-Allocator* AllocatorRegistry::GetAllocator() {
- return CHECK_NOTNULL(m_curr_allocator_);
+SubAllocator* AllocatorFactoryRegistry::GetSubAllocator(int numa_node) {
+ mutex_lock l(mu_);
+ first_alloc_made_ = true;
+ FactoryEntry* best_entry = nullptr;
+ for (auto& entry : factories_) {
+ if (best_entry == nullptr) {
+ best_entry = &entry;
+ } else if (best_entry->factory->NumaEnabled()) {
+ if (entry.factory->NumaEnabled() &&
+ (entry.priority > best_entry->priority)) {
+ best_entry = &entry;
+ }
+ } else {
+ DCHECK(!best_entry->factory->NumaEnabled());
+ if (entry.factory->NumaEnabled() ||
+ (entry.priority > best_entry->priority)) {
+ best_entry = &entry;
+ }
+ }
+ }
+ if (best_entry) {
+ int index = 0;
+ if (numa_node != port::kNUMANoAffinity) {
+ CHECK_LE(numa_node, port::NUMANumNodes());
+ index = 1 + numa_node;
+ }
+ if (best_entry->sub_allocators.size() < (index + 1)) {
+ best_entry->sub_allocators.resize(index + 1);
+ }
+ if (!best_entry->sub_allocators[index].get()) {
+ best_entry->sub_allocators[index].reset(
+ best_entry->factory->CreateSubAllocator(numa_node));
+ }
+ return best_entry->sub_allocators[index].get();
+ } else {
+ LOG(FATAL) << "No registered CPU AllocatorFactory";
+ return nullptr;
+ }
}
} // namespace tensorflow