about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/framework/allocator_registry.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/allocator_registry.h')
-rw-r--r-- tensorflow/core/framework/allocator_registry.h 111
1 files changed, 76 insertions, 35 deletions
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index b26e79ac3b..24f282ce84 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Classes to maintain a static registry of memory allocators
+// Classes to maintain a static registry of memory allocator factories.
#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
@@ -21,59 +21,100 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
-// A global AllocatorRegistry is used to hold allocators for CPU backends
-class AllocatorRegistry {
+class AllocatorFactory {
public:
- // Add an allocator to the registry. Caller releases ownership of
- // 'allocator'.
- void Register(const string& name, int priority, Allocator* allocator);
+ virtual ~AllocatorFactory() {}
- // Return allocator with highest priority
- // If multiple allocators have the same high priority, return one of them
+ // Returns true if the factory will create a functionally different
+ // SubAllocator for different (legal) values of numa_node.
+ virtual bool NumaEnabled() { return false; }
+
+ // Create an Allocator.
+ virtual Allocator* CreateAllocator() = 0;
+
+ // Create a SubAllocator. If NumaEnabled() is true, then the returned
+ // SubAllocator will allocate memory local to numa_node. If numa_node ==
+ // kNUMANoAffinity then allocated memory is not specific to any NUMA node.
+ virtual SubAllocator* CreateSubAllocator(int numa_node) = 0;
+};
+
+// A singleton registry of AllocatorFactories.
+//
+// Allocators should be obtained through ProcessState or cpu_allocator()
+// (deprecated), not directly through this interface. The purpose of this
+// registry is to allow link-time discovery of multiple AllocatorFactories among
+// which ProcessState will obtain the best fit at startup.
+class AllocatorFactoryRegistry {
+ public:
+ AllocatorFactoryRegistry() {}
+ ~AllocatorFactoryRegistry() {}
+
+ void Register(const char* source_file, int source_line, const string& name,
+ int priority, AllocatorFactory* factory);
+
+ // Returns the 'best fit' Allocator: finds the factory with the highest
+ // priority and returns an allocator constructed by it. If multiple factories
+ // have been registered with the same priority, picks one by unspecified criteria.
Allocator* GetAllocator();
- // Returns the global registry of allocators.
- static AllocatorRegistry* Global();
+ // Returns the 'best fit' SubAllocator. First looks for the highest priority
+ // factory that is NUMA-enabled. If none is registered, falls back to the
+ // highest priority non-NUMA-enabled factory. If the chosen factory is
+ // NUMA-enabled, returns a SubAllocator specific to numa_node; otherwise
+ // returns a NUMA-insensitive SubAllocator.
+ SubAllocator* GetSubAllocator(int numa_node);
+
+ // Returns the singleton value.
+ static AllocatorFactoryRegistry* singleton();
private:
- typedef struct {
+ mutex mu_;
+ bool first_alloc_made_ = false;
+ struct FactoryEntry {
+ const char* source_file;
+ int source_line;
string name;
int priority;
- Allocator* allocator; // not owned
- } AllocatorRegistryEntry;
-
- // Returns the Allocator registered for 'name' and 'priority',
- // or 'nullptr' if not found.
- Allocator* GetRegisteredAllocator(const string& name, int priority);
-
- std::vector<AllocatorRegistryEntry> allocators_;
- Allocator* m_curr_allocator_; // not owned
+ std::unique_ptr<AllocatorFactory> factory;
+ std::unique_ptr<Allocator> allocator;
+ // Index 0 corresponds to kNUMANoAffinity; every other index is
+ // (numa_node + 1).
+ std::vector<std::unique_ptr<SubAllocator>> sub_allocators;
+ };
+ std::vector<FactoryEntry> factories_ GUARDED_BY(mu_);
+
+ // Returns any FactoryEntry registered under 'name' and 'priority',
+ // or 'nullptr' if none found.
+ const FactoryEntry* FindEntry(const string& name, int priority) const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AllocatorFactoryRegistry);
};
-namespace allocator_registration {
-
-class AllocatorRegistration {
+class AllocatorFactoryRegistration {
public:
- AllocatorRegistration(const string& name, int priority,
- Allocator* allocator) {
- AllocatorRegistry::Global()->Register(name, priority, allocator);
+ AllocatorFactoryRegistration(const char* file, int line, const string& name,
+ int priority, AllocatorFactory* factory) {
+ AllocatorFactoryRegistry::singleton()->Register(file, line, name, priority,
+ factory);
}
};
-} // namespace allocator_registration
-
-#define REGISTER_MEM_ALLOCATOR(name, priority, allocator) \
- REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, name, priority, allocator)
+#define REGISTER_MEM_ALLOCATOR(name, priority, factory) \
+ REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, name, \
+ priority, factory)
-#define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, name, priority, allocator) \
- REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator)
+#define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, file, line, name, priority, \
+ factory) \
+ REGISTER_MEM_ALLOCATOR_UNIQ(ctr, file, line, name, priority, factory)
-#define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator) \
- static allocator_registration::AllocatorRegistration \
- register_allocator_##ctr(name, priority, new allocator)
+#define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, file, line, name, priority, factory) \
+ static AllocatorFactoryRegistration allocator_factory_reg_##ctr( \
+ file, line, name, priority, new factory)
} // namespace tensorflow