// path: root/tensorflow/core/framework/op_segment.h
// blob: 55249d2a382f61d99adcb9d6d7a0b5887f077d2d
#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_

#include <functional>
#include <string>
#include <unordered_map>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// OpSegment keeps track of OpKernels registered for sessions running
// on a device.
//
// The implementation maintains a two-level map. The 1st level maps
// session handle to the map of registered OpKernels. The 2nd level
// map maps node names to instantiated OpKernel objects.
//
// Each 2-nd level map is reference-counted and the caller can call
// AddHold to obtain a reference on all kernels of a session and
// ensure these kernels are alive until a corresponding RemoveHold is
// called on the same session.
class OpSegment {
 public:
  OpSegment();
  ~OpSegment();

  // A hold can be placed on a session, preventing all its kernels
  // from being deleted.
  void AddHold(const string& session_handle);
  void RemoveHold(const string& session_handle);

  // If the kernel for "node_name" has been created in the
  // "session_handle", returns the existing op kernel in "*kernel".
  // Otherwise, creates the kernel by calling create_fn(), cache it,
  // and returns it in "*kernel". If create_fn() fails, returns the
  // error.
  //
  // OpSegment keeps the ownership of the returned "*kernel".
  typedef std::function<Status(OpKernel**)> CreateKernelFn;
  Status FindOrCreate(const string& session_handle, const string& node_name,
                      OpKernel** kernel, CreateKernelFn create_fn);

 private:
  // op name -> OpKernel
  typedef std::unordered_map<string, OpKernel*> KernelMap;
  struct Item {
    int num_holds = 1;      // Num of holds put on the session.
    KernelMap name_kernel;  // op name -> kernel.
    ~Item();
  };

  // session handle -> item.
  // Session handles are produced by strings::FpToString()
  typedef std::unordered_map<string, Item*> SessionMap;

  mutable mutex mu_;
  SessionMap sessions_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(OpSegment);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_