aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/session_factory.cc
blob: 666b99812d4035067dd4083626e96dd49cdc5e07 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include "tensorflow/core/common_runtime/session_factory.h"

#include <unordered_map>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
namespace {

// Returns a pointer to the single mutex that guards the session factory
// registry. The mutex is a function-local static, so it is constructed the
// first time this function is called rather than at namespace scope.
static mutex* get_session_factory_lock() {
  static mutex registry_mu;
  mutex* const registry_mu_ptr = &registry_mu;
  return registry_mu_ptr;
}

typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
  static SessionFactories* factories = new SessionFactories;
  return factories;
}

}  // namespace

// Registers `factory` under the name `runtime_type`.
//
// The registry lock is held for the duration of the call. If a factory is
// already registered under `runtime_type`, the existing entry is kept
// (unordered_map::insert does not overwrite), `factory` is not stored, and an
// error is logged.
void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  if (!session_factories()->insert({runtime_type, factory}).second) {
    // Note the trailing space after "under": without it the runtime type name
    // is glued directly onto the word in the emitted log line.
    LOG(ERROR) << "Two session factories are being registered "
               << "under " << runtime_type;
  }
}

// Looks up the factory registered under `runtime_type`.
//
// Returns the registered factory pointer, or nullptr when no factory has been
// registered under that name. The registry lock is held during the lookup.
SessionFactory* SessionFactory::GetFactory(const string& runtime_type) {
  mutex_lock l(*get_session_factory_lock());  // could use reader lock
  SessionFactories* const registry = session_factories();
  const auto entry = registry->find(runtime_type);
  return entry != registry->end() ? entry->second : nullptr;
}

}  // namespace tensorflow