diff options
author | 2017-04-26 12:03:45 -0800 | |
---|---|---|
committer | 2017-04-26 13:33:31 -0700 | |
commit | 0293b46e724b85546a2175fdcb3992f44f3c0ef4 (patch) | |
tree | 9f462b80408ffe3d1bd63b49fee9c4ed49460e5d /tensorflow | |
parent | aa1f99845dacba0f37f1b6fad5e51ce7688ee1c3 (diff) |
AllocationRegistry: only check fail if two different allocator types
are defined for the same name and priority.
Change: 154333776
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/framework/allocator_registry.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/framework/allocator_registry.h | 7 |
2 files changed, 26 insertions, 8 deletions
diff --git a/tensorflow/core/framework/allocator_registry.cc b/tensorflow/core/framework/allocator_registry.cc index 946050687d..486be39ae3 100644 --- a/tensorflow/core/framework/allocator_registry.cc +++ b/tensorflow/core/framework/allocator_registry.cc @@ -26,22 +26,37 @@ AllocatorRegistry* AllocatorRegistry::Global() { return global_allocator_registry; } -bool AllocatorRegistry::CheckForDuplicates(const string& name, int priority) { +Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name, + int priority) { for (auto entry : allocators_) { if (!name.compare(entry.name) && priority == entry.priority) { - return true; + return entry.allocator; } } - return false; + return nullptr; } void AllocatorRegistry::Register(const string& name, int priority, Allocator* allocator) { CHECK(!name.empty()) << "Need a valid name for Allocator"; CHECK_GE(priority, 0) << "Priority needs to be non-negative"; - CHECK(!CheckForDuplicates(name, priority)) - << "Allocator with name: [" << name << "] and priority: [" << priority - << "] already registered"; + + Allocator* existing = GetRegisteredAllocator(name, priority); + if (existing != nullptr) { + // A duplicate is if the registration name and priority match + // but the Allocator::Name()'s don't match. + CHECK_EQ(existing->Name(), allocator->Name()) + << "Allocator with name: [" << name << "], type [" << existing->Name() + << "], priority: [" << priority + << "] already registered. Choose a different name to register " + << "an allocator of type " << allocator->Name(); + + // The allocator names match, so we can just return. + // It should be safe to delete the allocator since the caller + // gives up ownership of it. + delete allocator; + return; + } AllocatorRegistryEntry tmp_entry; tmp_entry.name = name; diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index c419366ae1..b26e79ac3b 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -27,7 +27,8 @@ namespace tensorflow { // A global AllocatorRegistry is used to hold allocators for CPU backends class AllocatorRegistry { public: - // Add an allocator to the registry. + // Add an allocator to the registry. Caller releases ownership of + // 'allocator'. void Register(const string& name, int priority, Allocator* allocator); // Return allocator with highest priority @@ -44,7 +45,9 @@ class AllocatorRegistry { Allocator* allocator; // not owned } AllocatorRegistryEntry; - bool CheckForDuplicates(const string& name, int priority); + // Returns the Allocator registered for 'name' and 'priority', + // or 'nullptr' if not found. + Allocator* GetRegisteredAllocator(const string& name, int priority); std::vector<AllocatorRegistryEntry> allocators_; Allocator* m_curr_allocator_; // not owned |