aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/cancellation.h
blob: feda548e9724beab3bf9a1b206431d6f8b2483b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_

#include <atomic>
#include <functional>
#include <unordered_map>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// Opaque handle under which a CancelCallback is registered with, and later
// deregistered from, a CancellationManager.
//
// Do not invent token values: every CancellationToken must be obtained
// from CancellationManager::get_cancellation_token.
using CancellationToken = int64;

// Nullary callback that runs when a step is cancelled.
//
// NOTE(mrry): Implementations are subject to the caveats listed in the
// comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;

// Coordinates cancellation of a step. Clients obtain a token, register a
// callback under that token, and deregister it once their operation
// completes; StartCancel() runs the callbacks associated with the manager.
class CancellationManager {
 public:
  // A value that won't be returned by get_cancellation_token().
  static const CancellationToken kInvalidToken;

  CancellationManager();
  ~CancellationManager();

  // Run all callbacks associated with this manager.
  void StartCancel();

  // Returns true iff StartCancel() has been called.
  //
  // This is a read-only query (an acquire load of an atomic flag), so it
  // is const-qualified and may be called through a const reference or
  // pointer to the manager.
  bool IsCancelled() const {
    return is_cancelled_.load(std::memory_order_acquire);
  }

  // Returns a token that must be used in calls to RegisterCallback
  // and DeregisterCallback.
  CancellationToken get_cancellation_token();

  // Attempts to register the given callback to be invoked when this
  // manager is cancelled. Returns true if the callback was
  // registered; returns false if this manager was already cancelled,
  // and the callback was not registered.
  //
  // If this method returns false, it is the caller's responsibility
  // to perform any cancellation cleanup.
  //
  // This method is tricky to use correctly. The following usage pattern
  // is recommended:
  //
  // class ObjectWithCancellableOperation {
  //   mutex mu_;
  //   void CancellableOperation(CancellationManager* cm,
  //                             std::function<void(Status)> callback) {
  //     bool already_cancelled;
  //     CancellationToken token = cm->get_cancellation_token();
  //     {
  //       mutex_lock l(mu_);
  //       // RegisterCallback returns true on success, so a false result
  //       // means the manager had already been cancelled.
  //       already_cancelled = !cm->RegisterCallback(
  //           token, [this, token]() { Cancel(token); });
  //       if (!already_cancelled) {
  //         // Issue asynchronous operation. Associate the pending operation
  //         // with `token` in some object state, or provide another way for
  //         // the Cancel method to look up the operation for cancellation.
  //         // Ensure that `cm->DeregisterCallback(token)` is called without
  //         // holding `mu_`, before `callback` is invoked.
  //         // ...
  //       }
  //     }
  //     if (already_cancelled) {
  //       callback(errors::Cancelled("Operation was cancelled"));
  //     }
  //   }
  //
  //   void Cancel(CancellationToken token) {
  //     mutex_lock l(mu_);
  //     // Take action to cancel the operation with the given cancellation
  //     // token.
  //   }
  // };
  //
  // NOTE(mrry): The caller should take care that (i) the calling code
  // is robust to `callback` being invoked asynchronously (e.g. from
  // another thread), (ii) `callback` is deregistered by a call to
  // this->DeregisterCallback(token) when the operation completes
  // successfully, and (iii) `callback` does not invoke any method
  // on this cancellation manager. Furthermore, it is important that
  // the eventual caller of the complementary DeregisterCallback does not
  // hold any mutexes that are required by `callback`.
  bool RegisterCallback(CancellationToken token, CancelCallback callback);

  // Deregister the callback that, when registered, was associated
  // with the given cancellation token. Returns true iff the callback
  // was deregistered and will not be invoked; otherwise returns false
  // after the callback has been invoked, blocking if necessary.
  //
  // NOTE(mrry): This method may block if cancellation is in progress.
  // The caller of this method must not hold any mutexes that are required
  // to invoke any cancellation callback that has been registered with this
  // cancellation manager.
  bool DeregisterCallback(CancellationToken token);

 private:
  // NOTE(review): presumably true while StartCancel() is running the
  // callbacks; it carries no GUARDED_BY annotation here, so confirm the
  // locking discipline against cancellation.cc.
  bool is_cancelling_;
  // Flag read by IsCancelled() with acquire ordering.
  std::atomic_bool is_cancelled_;

  mutex mu_;
  // NOTE(review): presumably what DeregisterCallback() blocks on while
  // cancellation is in progress -- confirm against cancellation.cc.
  Notification cancelled_notification_;
  // NOTE(review): presumably the source of the values handed out by
  // get_cancellation_token() -- confirm against cancellation.cc.
  CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
  // Callbacks keyed by the token they are associated with.
  std::unordered_map<CancellationToken, CancelCallback> callbacks_
      GUARDED_BY(mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_CANCELLATION_H_