aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-04-26 12:03:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-26 13:33:31 -0700
commit0293b46e724b85546a2175fdcb3992f44f3c0ef4 (patch)
tree9f462b80408ffe3d1bd63b49fee9c4ed49460e5d /tensorflow
parentaa1f99845dacba0f37f1b6fad5e51ce7688ee1c3 (diff)
AllocationRegistry: only check fail if two different allocator types
are defined for the same name and priority. Change: 154333776
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/framework/allocator_registry.cc27
-rw-r--r--tensorflow/core/framework/allocator_registry.h7
2 files changed, 26 insertions, 8 deletions
diff --git a/tensorflow/core/framework/allocator_registry.cc b/tensorflow/core/framework/allocator_registry.cc
index 946050687d..486be39ae3 100644
--- a/tensorflow/core/framework/allocator_registry.cc
+++ b/tensorflow/core/framework/allocator_registry.cc
@@ -26,22 +26,37 @@ AllocatorRegistry* AllocatorRegistry::Global() {
return global_allocator_registry;
}
-bool AllocatorRegistry::CheckForDuplicates(const string& name, int priority) {
+Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name,
+ int priority) {
for (auto entry : allocators_) {
if (!name.compare(entry.name) && priority == entry.priority) {
- return true;
+ return entry.allocator;
}
}
- return false;
+ return nullptr;
}
void AllocatorRegistry::Register(const string& name, int priority,
Allocator* allocator) {
CHECK(!name.empty()) << "Need a valid name for Allocator";
CHECK_GE(priority, 0) << "Priority needs to be non-negative";
- CHECK(!CheckForDuplicates(name, priority))
- << "Allocator with name: [" << name << "] and priority: [" << priority
- << "] already registered";
+
+ Allocator* existing = GetRegisteredAllocator(name, priority);
+ if (existing != nullptr) {
+ // A duplicate is if the registration name and priority match
+ // but the Allocator::Name()'s don't match.
+ CHECK_EQ(existing->Name(), allocator->Name())
+ << "Allocator with name: [" << name << "], type [" << existing->Name()
+ << "], priority: [" << priority
+ << "] already registered. Choose a different name to register "
+ << "an allocator of type " << allocator->Name();
+
+ // The allocator names match, so we can just return.
+ // It should be safe to delete the allocator since the caller
+ // gives up ownership of it.
+ delete allocator;
+ return;
+ }
AllocatorRegistryEntry tmp_entry;
tmp_entry.name = name;
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index c419366ae1..b26e79ac3b 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -27,7 +27,8 @@ namespace tensorflow {
// A global AllocatorRegistry is used to hold allocators for CPU backends
class AllocatorRegistry {
public:
- // Add an allocator to the registry.
+ // Add an allocator to the registry. Caller releases ownership of
+ // 'allocator'.
void Register(const string& name, int priority, Allocator* allocator);
// Return allocator with highest priority
@@ -44,7 +45,9 @@ class AllocatorRegistry {
Allocator* allocator; // not owned
} AllocatorRegistryEntry;
- bool CheckForDuplicates(const string& name, int priority);
+ // Returns the Allocator registered for 'name' and 'priority',
+ // or 'nullptr' if not found.
+ Allocator* GetRegisteredAllocator(const string& name, int priority);
std::vector<AllocatorRegistryEntry> allocators_;
Allocator* m_curr_allocator_; // not owned