aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/resource_mgr.h
blob: 65e859caf1dfb3e387b48107b8cd909eaa8243d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_

#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// A ResourceMgr instance keeps track of named and typed resources
// grouped into containers.
//
// Each resource must be represented as a sub-class of ResourceBase,
// which is reference counted explicitly.  Each named resource is
// registered with ResourceMgr under a named "container" name. At any
// time, there is at most one instance of a resource given the container
// name, the resource type and the resource name.
//
// All resources for a given container can be dropped by one call of
// Cleanup().
//
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//   }
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var.val = Tensor(DT_FLOAT, my_shape);
//   my_val.val.flat<float>().setZeros();   // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_val));
//
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     my_var->val.flat<float>() += grad;
//   }
//   my_var->Unref();   // Or use ScopedUnref().
//   ctx->SetStatus(s);
class ResourceBase : public core::RefCounted {
 public:
  // Returns a debug string for *this.
  virtual string DebugString() = 0;
};

class ResourceMgr {
 public:
  ResourceMgr();
  explicit ResourceMgr(const string& default_container);
  ~ResourceMgr();

  // Returns the default container name for *this.
  const string& default_container() const { return default_container_; }

  // Creates a resource "name" in the "container".  The caller transfers
  // the ownership of one ref on "resource" to *this
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr.
  template <typename T>
  Status Create(const string& container, const string& name,
                T* resource) TF_MUST_USE_RESULT;

  // If "container" has a resource "name", returns it in "*resource" and
  // the caller takes the ownership of one ref on "*resource".
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr
  template <typename T>
  Status Lookup(const string& container, const string& name,
                T** resource) const TF_MUST_USE_RESULT;

  // If "container" has a resource "name", returns it in
  // "*resource". Otherwise, invokes creator() to create the resource.
  // The caller takes the ownership of one ref on "*resource".
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr
  template <typename T>
  Status LookupOrCreate(const string& container, const string& name,
                        T** resource,
                        std::function<Status(T**)> creator) TF_MUST_USE_RESULT;

  // Deletes the resource "name" from the "container".
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  template <typename T>
  Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;

  // Deletes all resources from the "container" and removes the container.
  Status Cleanup(const string& container) TF_MUST_USE_RESULT;

  // Deletes all resources in all containers.
  void Clear();

 private:
  typedef std::pair<std::type_index, string> Key;
  struct KeyHash {
    std::size_t operator()(const Key& k) const {
      return Hash64(k.second.data(), k.second.size(), k.first.hash_code());
    }
  };
  struct KeyEqual {
    bool operator()(const Key& x, const Key& y) const {
      return (x.second == y.second) && (x.first == y.first);
    }
  };
  typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;

  const string default_container_;
  mutable mutex mu_;
  std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);

  Status DoCreate(const string& container, std::type_index type,
                  const string& name,
                  ResourceBase* resource) TF_MUST_USE_RESULT;
  Status DoLookup(const string& container, std::type_index type,
                  const string& name,
                  ResourceBase** resource) const TF_MUST_USE_RESULT;
  Status DoDelete(const string& container, std::type_index type,
                  const string& name) TF_MUST_USE_RESULT;

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
};

// Policy helper to decide which container/shared_name to use for a
// stateful kernel that accesses shared resource.
class ContainerInfo {
 public:
  // Analyze the node attribute of 'ndef' and decides the container and
  // resource name the kernel should use for accessing the shared
  // resource.
  //
  // 'ndef' is expected to have node attribute "container" and
  // "shared_name". Returns non-OK if they are not provided or they are
  // invalid.
  //
  // The policy is as following:
  // * If the attribute "container" is non-empty, it is used as is.
  //   Otherwise, uses the resource manager's default container.
  // * If the attribute "shared_name" is non-empty, it is used as is.
  //   Otherwise, if "use_node_name_as_default" is true, the kernel's
  //   node name is used as the resource name. Otherwise, a string
  //   unique to this process is used.
  Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
              bool use_node_name_as_default);
  Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
    return Init(rmgr, ndef, false);
  }

  // The policy decides that the kernel should access the resource in
  // resource_manager(), the resource is in the container() and its
  // name is name().  If resource_is_private_to_kernel() is true, the
  // kernel should delete the resource when the kernel is deleted.
  ResourceMgr* resource_manager() const { return rmgr_; }
  const string& container() const { return container_; }
  const string& name() const { return name_; }
  bool resource_is_private_to_kernel() const {
    return resource_is_private_to_kernel_;
  }

  // Returns a readable string for *this.
  string DebugString() const;

 private:
  ResourceMgr* rmgr_ = nullptr;
  string container_;
  string name_;
  bool resource_is_private_to_kernel_ = false;
};

// Helper for kernels to obtain 'resource' from the
// ctx->resource_manager().
//
// "input_name" specifies the kernel's ref input which gives a string
// tensor with two elements, which specifies the container and
// resource name.
//
// Returns OK if the resource is found and transfers one ref of
// *resource to the caller. Otherwise, returns an error.
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
                              T** resource);

// Implementation details below.

template <typename T>
void CheckDeriveFromResourceBase() {
  static_assert(std::is_base_of<ResourceBase, T>::value,
                "T must derive from ResourceBase");
}

template <typename T>
Status ResourceMgr::Create(const string& container, const string& name,
                           T* resource) {
  CheckDeriveFromResourceBase<T>();
  CHECK(resource != nullptr);
  return DoCreate(container, std::type_index(typeid(T)), name, resource);
}

template <typename T>
Status ResourceMgr::Lookup(const string& container, const string& name,
                           T** resource) const {
  CheckDeriveFromResourceBase<T>();
  ResourceBase* found = nullptr;
  Status s = DoLookup(container, std::type_index(typeid(T)), name, &found);
  if (s.ok()) {
    // It's safe to down cast 'found' to T* since
    // typeid(T).hash_code() is part of the map key.
    *resource = static_cast<T*>(found);
  }
  return s;
}

template <typename T>
Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
                                   T** resource,
                                   std::function<Status(T**)> creator) {
  Status s;
  *resource = nullptr;
  while (*resource == nullptr) {
    s = Lookup(container, name, resource);
    if (s.ok()) break;
    s = creator(resource);
    if (!s.ok()) break;
    s = Create(container, name, *resource);
    if (s.ok()) {
      (*resource)->Ref();
      break;
    }
    // Rare event. Concurrent racy creation. Redo the lookup.
    *resource = nullptr;
  }
  return s;
}

template <typename T>
Status ResourceMgr::Delete(const string& container, const string& name) {
  CheckDeriveFromResourceBase<T>();
  return DoDelete(container, std::type_index(typeid(T)), name);
}

template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
                              T** resource) {
  string container;
  string shared_name;
  {
    mutex* mu;
    TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
    mutex_lock l(*mu);
    Tensor tensor;
    TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
    if (tensor.NumElements() != 2) {
      return errors::InvalidArgument(
          "Resource handle must have 2 elements, but had shape: ",
          tensor.shape().DebugString());
    }
    container = tensor.flat<string>()(0);
    shared_name = tensor.flat<string>()(1);
  }
  return ctx->resource_manager()->Lookup(container, shared_name, resource);
}

}  //  end namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_