aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/allocator_registry.h
blob: c419366ae1aa6f35cf98c351844d930bf1b49728 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes that maintain a static registry of memory allocators.
#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/allocator.h"

namespace tensorflow {

// A global AllocatorRegistry is used to hold allocators for CPU backends.
// Each allocator is registered under a name and an integer priority; callers
// ask the registry for the currently preferred (highest-priority) allocator.
class AllocatorRegistry {
 public:
  // Adds an allocator to the registry under `name` with the given `priority`.
  // The registry does not take ownership of `allocator`; the caller is
  // responsible for keeping it alive while the registry may hand it out.
  void Register(const string& name, int priority, Allocator* allocator);

  // Returns the allocator with the highest priority.
  // If multiple allocators share the highest priority, one of them is
  // returned.
  Allocator* GetAllocator();

  // Returns the global registry of allocators.
  static AllocatorRegistry* Global();

 private:
  // One registered allocator together with the name and priority it was
  // registered with. Kept as a plain aggregate so existing member-wise
  // initialization in the implementation keeps working.
  struct AllocatorRegistryEntry {
    string name;
    int priority;
    Allocator* allocator;  // not owned
  };

  // Reports whether an allocator with this `name` and `priority` is already
  // registered (presumably used by Register to reject duplicates — confirm in
  // the implementation file).
  bool CheckForDuplicates(const string& name, int priority);

  std::vector<AllocatorRegistryEntry> allocators_;

  // The allocator GetAllocator() hands out (presumably maintained by Register
  // as the highest-priority entry — confirm in the implementation file).
  // Explicitly null until the first registration so it is never read while
  // indeterminate, e.g. by a null check in GetAllocator().
  Allocator* m_curr_allocator_ = nullptr;  // not owned
};

namespace allocator_registration {

class AllocatorRegistration {
 public:
  AllocatorRegistration(const string& name, int priority,
                        Allocator* allocator) {
    AllocatorRegistry::Global()->Register(name, priority, allocator);
  }
};

}  // namespace allocator_registration

// REGISTER_MEM_ALLOCATOR(name, priority, allocator) registers a new instance
// of the type `allocator` (created with `new` and never freed by this macro)
// under `name` with `priority`. It does so by defining a uniquely named static
// AllocatorRegistration object.
//
// __COUNTER__ is routed through REGISTER_MEM_ALLOCATOR_UNIQ_HELPER so that it
// is expanded to a number before REGISTER_MEM_ALLOCATOR_UNIQ pastes it with
// `##`; pasting it directly would produce the same identifier every time.
//
// The registration type is fully qualified so the macro also works when used
// outside of namespace tensorflow (macros do not respect namespaces).
#define REGISTER_MEM_ALLOCATOR(name, priority, allocator) \
  REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, name, priority, allocator)

#define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, name, priority, allocator) \
  REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator)

#define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator)   \
  static ::tensorflow::allocator_registration::AllocatorRegistration \
      register_allocator_##ctr(name, priority, new allocator)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_