aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_factory.cc
blob: 7d391bde1dd921d32c1d5d6e2d2d2a68c2b44f2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include "tensorflow/core/common_runtime/device_factory.h"

#include <memory>
#include <string>
#include <unordered_map>

#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

// Returns the mutex that guards the device factory registry returned by
// device_factories().
//
// The mutex is heap-allocated and intentionally never destroyed, matching the
// registry it protects. A function-local `static mutex` object would be torn
// down during static destruction while the (leaked) registry remains alive, so
// any registry access from another object's static destructor would lock a
// destroyed mutex. Leaking both keeps them usable for the whole process
// lifetime.
//
// (No `static` keyword is needed: the unnamed namespace already gives this
// function internal linkage.)
mutex* get_device_factory_lock() {
  static mutex* device_factory_lock = new mutex;
  return device_factory_lock;
}

// A registered factory together with the priority it was registered with.
// DeviceFactory::Register keeps only the highest-priority factory per type.
struct FactoryItem {
  std::unique_ptr<DeviceFactory> factory;
  int priority;
};

// Returns the process-wide registry mapping a device type string (e.g. "CPU",
// "GPU") to its registered factory. Callers must hold
// *get_device_factory_lock() while reading or mutating it.
//
// Heap-allocated and never deleted so that it outlives all static destructors.
std::unordered_map<string, FactoryItem>& device_factories() {
  static std::unordered_map<string, FactoryItem>* factories =
      new std::unordered_map<string, FactoryItem>;
  return *factories;
}

}  // namespace

void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
                             int priority) {
  mutex_lock l(*get_device_factory_lock());
  std::unique_ptr<DeviceFactory> factory_ptr(factory);
  std::unordered_map<string, FactoryItem>& factories = device_factories();
  auto iter = factories.find(device_type);
  if (iter == factories.end()) {
    factories[device_type] = {std::move(factory_ptr), priority};
  } else {
    if (iter->second.priority < priority) {
      iter->second = {std::move(factory_ptr), priority};
    } else if (iter->second.priority == priority) {
      LOG(FATAL) << "Duplicate registration of device factory for type "
                 << device_type << " with the same priority " << priority;
    }
  }
}

// Returns the factory registered for `device_type`, or nullptr if none is.
// The returned pointer is owned by the registry; callers must not delete it.
DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
  // Only reads the registry, so a reader (shared) lock would suffice.
  mutex_lock guard(*get_device_factory_lock());
  const auto& registry = device_factories();
  auto found = registry.find(device_type);
  return found == registry.end() ? nullptr : found->second.factory.get();
}

void DeviceFactory::AddDevices(const SessionOptions& options,
                               const string& name_prefix,
                               std::vector<Device*>* devices) {
  // CPU first.
  auto cpu_factory = GetFactory("CPU");
  if (!cpu_factory) {
    LOG(FATAL)
        << "CPU Factory not registered.  Did you link in threadpool_device?";
  }
  size_t init_size = devices->size();
  cpu_factory->CreateDevices(options, name_prefix, devices);
  if (devices->size() == init_size) {
    LOG(FATAL) << "No CPU devices are available in this process";
  }

  // Then GPU.
  auto gpu_factory = GetFactory("GPU");
  if (gpu_factory) {
    gpu_factory->CreateDevices(options, name_prefix, devices);
  }

  // Then the rest.
  mutex_lock l(*get_device_factory_lock());
  for (auto& p : device_factories()) {
    auto factory = p.second.factory.get();
    if (factory != cpu_factory && factory != gpu_factory) {
      factory->CreateDevices(options, name_prefix, devices);
    }
  }
}

// Creates a single device of `type` using the registered factory.
//
// Returns nullptr if no factory is registered for `type`. Otherwise asks the
// factory for exactly one device by overriding the per-type device count in a
// copy of `options`. CHECK-fails if the factory does not produce exactly one
// device. The returned pointer is the device the factory created.
Device* DeviceFactory::NewDevice(const string& type,
                                 const SessionOptions& options,
                                 const string& name_prefix) {
  DeviceFactory* const factory = GetFactory(type);
  if (factory == nullptr) return nullptr;

  SessionOptions single_device_options(options);
  auto& counts = *single_device_options.config.mutable_device_count();
  counts[type] = 1;

  std::vector<Device*> devices;
  factory->CreateDevices(single_device_options, name_prefix, &devices);
  CHECK_EQ(devices.size(), 1);
  return devices.front();
}

}  // namespace tensorflow