aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_segment.cc
blob: a39bebd854ccdd56459f3fc73889518cafcfc584 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "tensorflow/core/framework/op_segment.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {

// Destroys every kernel pointer stored in this item's name_kernel map.
OpSegment::Item::~Item() {
  // Bind by const reference: iterating by value would copy each map entry
  // (key included) only to read the pointer out of it.
  for (const auto& kv : name_kernel) delete kv.second;
}

// Nothing to set up explicitly; members are default-initialized.
OpSegment::OpSegment() = default;

// Destroys every Item still registered in sessions_.
OpSegment::~OpSegment() {
  // Bind by const reference: iterating by value would copy each map entry
  // (key included) only to read the pointer out of it.
  for (const auto& kv : sessions_) delete kv.second;
}

// Looks up the kernel registered under 'node_name' for the session
// 'session_handle'. If none is registered yet, builds one with 'create_fn'
// and registers it.
//
// Returns:
//   - OK, with '*kernel' set to the kernel stored in the session's map.
//   - NotFound if the session is not present in sessions_, either on entry
//     or by the time the freshly created kernel is about to be registered.
//   - Whatever error 'create_fn' reports.
//
// 'create_fn' is invoked without mu_ held, so another caller may register a
// kernel for the same node, or the session may be removed, in the meantime.
// Both cases are handled in the second critical section below.
Status OpSegment::FindOrCreate(const string& session_handle,
                               const string& node_name, OpKernel** kernel,
                               CreateKernelFn create_fn) {
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
    if (*kernel != nullptr) {
      return Status::OK();
    }
  }
  Status s = create_fn(kernel);
  if (!s.ok()) {
    LOG(ERROR) << "Create kernel failed: " << s;
    return s;
  }
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      // The session went away while the kernel was being created. Nothing
      // will ever own the new kernel, so release it here instead of leaking
      // it, and do not hand the caller a pointer it has no contract to free.
      delete *kernel;
      *kernel = nullptr;
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    OpKernel** p_kernel = &(item->name_kernel[node_name]);
    if (*p_kernel == nullptr) {
      *p_kernel = *kernel;  // Inserts 'kernel' in the map.
    } else {
      // Another caller registered a kernel for this node first; discard ours
      // and return the one already in the map.
      delete *kernel;
      *kernel = *p_kernel;
    }
  }
  return Status::OK();
}

void OpSegment::AddHold(const string& session_handle) {
  mutex_lock l(mu_);
  Item** item = &sessions_[session_handle];
  if (*item == nullptr) {
    *item = new Item;  // num_holds == 1
  } else {
    ++((*item)->num_holds);
  }
}

void OpSegment::RemoveHold(const string& session_handle) {
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto siter = sessions_.find(session_handle);
    if (siter == sessions_.end()) {
      VLOG(1) << "Session " << session_handle << " is not found.";
      return;
    }
    item = siter->second;
    if (--(item->num_holds) > 0) {
      return;
    } else {
      sessions_.erase(siter);
    }
  }
  delete item;
}

}  // end namespace tensorflow