// path: root/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
// blob: f9436566d458a5e1bdd3b40d1a203ad076b8ffb3 (plain)
#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_

#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/tensor.h"

// Forward declarations of the StreamExecutor types referenced below. This
// header only names them through pointers, so complete definitions are not
// required here.
namespace perftools {
namespace gputools {
class Event;
class Stream;
class StreamExecutor;
}  // namespace gputools
}  // namespace perftools

namespace tensorflow {

// Keeps track of pending Events in the StreamExecutor streams together with
// the resources tied to them: Tensors and raw buffers that cannot safely be
// deleted until the associated Events are recorded, and callbacks queued
// alongside an Event.
//
// Each public Then*() method acquires mu_, queues its argument against the
// given stream, and then polls once (PollEvents(false)) while still holding
// the lock.
class EventMgr {
 public:
  explicit EventMgr(perftools::gputools::StreamExecutor* se);

  ~EventMgr();

  // Takes ownership of *tensors and deletes it as soon as all events
  // currently enqueued on *stream have completed.
  void ThenDeleteTensors(perftools::gputools::Stream* stream,
                         std::vector<Tensor>* tensors) {
    mutex_lock lock(mu_);
    QueueTensors(stream, tensors);
    PollEvents(false);
  }

  // A raw buffer paired with the allocator that must release it.
  struct BufRec {
    Allocator* alloc;  // Allocator whose DeallocateRaw() is called on buf.
    void* buf;         // The buffer to be released.
  };

  // Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw()
  // on it as soon as all events currently enqueued on *stream have completed.
  void ThenDeleteBuffer(perftools::gputools::Stream* stream, BufRec bufrec) {
    mutex_lock lock(mu_);
    QueueBuffer(stream, bufrec);
    PollEvents(false);
  }

  // Queues func against *stream and then polls for completed events.
  // NOTE(review): by analogy with ThenDeleteTensors()/ThenDeleteBuffer(), func
  // is presumably invoked once the events currently enqueued on *stream have
  // completed; the code that runs it is not in this header -- confirm against
  // the PollEvents() implementation.
  void ThenExecute(perftools::gputools::Stream* stream,
                   std::function<void()> func) {
    mutex_lock lock(mu_);
    // func is a by-value sink parameter: hand it on by move instead of
    // copying the callable while the lock is held.
    QueueFunc(stream, std::move(func));
    PollEvents(false);
  }

 private:
  friend class TEST_EventMgrHelper;
  mutex mu_;
  perftools::gputools::StreamExecutor* exec_;

  // One queued entry: an Event plus whatever was registered with it. The
  // Queue*() helpers below each populate exactly one of mem / bufrec / func
  // and leave event as nullptr for QueueInUse() to deal with.
  struct InUse {
    perftools::gputools::Event* event;
    std::vector<Tensor>* mem;
    BufRec bufrec;
    std::function<void()> func;
  };

  // Stream-enqueue an unused Event and save with it a collection of
  // Tensors and/or a BufRec to be deleted only after the Event
  // records.
  void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Queues an entry carrying only a Tensor collection.
  void QueueTensors(perftools::gputools::Stream* stream,
                    std::vector<Tensor>* tensors)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
  }

  // Queues an entry carrying only a raw buffer record.
  void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
  }

  // Queues an entry carrying only a callback. func is moved into the entry
  // rather than copied.
  void QueueFunc(perftools::gputools::Stream* stream,
                 std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    QueueInUse(stream, {nullptr, nullptr, BufRec(), std::move(func)});
  }

  // This function should be called at roughly the same tempo as
  // QueueTensors() to check whether pending events have recorded,
  // and then retire them. The public Then*() methods pass
  // is_dedicated_poller = false.
  void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // An internal polling loop that runs at a low frequency to clear
  // straggler Events.
  void PollLoop();

  // A stack of unused events
  std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_);

  // A FIFO queue of InUse events and associated tensors.
  std::deque<InUse> used_events_ GUARDED_BY(mu_);

  // NOTE(review): presumably used to ask PollLoop() to stop and to learn that
  // it has stopped, respectively; their use is not visible in this header.
  Notification stop_polling_;
  Notification polling_stopped_;

  // The main PollLoop for the event manager runs in this threadpool.
  thread::ThreadPool threadpool_;
};

}  // namespace tensorflow
#endif  // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_