blob: a39bebd854ccdd56459f3fc73889518cafcfc584 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
// Destroys the item and deletes every OpKernel cached in `name_kernel`;
// the item owns those kernels.
OpSegment::Item::~Item() {
  // Iterate by reference: a by-value loop variable would copy each map entry
  // (including its key) just to read the kernel pointer.
  for (const auto& kv : name_kernel) delete kv.second;
}
// Creates an empty segment with no sessions registered.
OpSegment::OpSegment() = default;
// Destroys the segment and deletes every remaining session Item. Each Item's
// destructor in turn deletes the kernels it cached. No lock is taken here;
// presumably no other thread may use the segment during destruction
// (NOTE(review): confirm with callers).
OpSegment::~OpSegment() {
  // Iterate by reference to avoid copying each (handle, Item*) entry.
  for (const auto& kv : sessions_) delete kv.second;
}
// Returns (via *kernel) the kernel cached for `node_name` in session
// `session_handle`. If none is cached yet, builds one with `create_fn` and
// caches it.
//
// On success, *kernel points at the cached kernel. The kernel stays owned by
// this segment: it is deleted when the session's Item is destroyed.
//
// Returns NotFound if `session_handle` has no registered Item, either at
// lookup time or after `create_fn` returns. Errors from `create_fn` are
// logged and propagated.
//
// `create_fn` is invoked without holding mu_. Two callers can therefore
// race to create a kernel for the same node. The first to publish it into
// the map wins; the loser deletes its own kernel and adopts the winner's.
Status OpSegment::FindOrCreate(const string& session_handle,
                               const string& node_name, OpKernel** kernel,
                               CreateKernelFn create_fn) {
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
    if (*kernel != nullptr) {
      return Status::OK();
    }
  }
  // Build the kernel outside the lock, presumably so that kernel
  // construction does not block other users of the segment.
  Status s = create_fn(kernel);
  if (!s.ok()) {
    LOG(ERROR) << "Create kernel failed: " << s;
    return s;
  }
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      // The session's last hold was removed while create_fn ran. Nothing
      // will take ownership of the new kernel, so free it rather than leak
      // it. Also clear *kernel so the caller is not left holding a pointer
      // that no one owns.
      delete *kernel;
      *kernel = nullptr;
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    OpKernel** p_kernel = &(item->name_kernel[node_name]);
    if (*p_kernel == nullptr) {
      *p_kernel = *kernel;  // Inserts 'kernel' in the map.
    } else {
      // Another caller cached a kernel for this node first; keep theirs.
      delete *kernel;
      *kernel = *p_kernel;
    }
  }
  return Status::OK();
}
void OpSegment::AddHold(const string& session_handle) {
mutex_lock l(mu_);
Item** item = &sessions_[session_handle];
if (*item == nullptr) {
*item = new Item; // num_holds == 1
} else {
++((*item)->num_holds);
}
}
void OpSegment::RemoveHold(const string& session_handle) {
Item* item = nullptr;
{
mutex_lock l(mu_);
auto siter = sessions_.find(session_handle);
if (siter == sessions_.end()) {
VLOG(1) << "Session " << session_handle << " is not found.";
return;
}
item = siter->second;
if (--(item->num_holds) > 0) {
return;
} else {
sessions_.erase(siter);
}
}
delete item;
}
} // end namespace tensorflow
|