1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
#include <atomic>
#include <functional>
#include <unordered_map>
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"
namespace tensorflow {
// Opaque handle linking a CancelCallback to a CancellationManager. The same
// value is passed to CancellationManager::RegisterCallback and, later, to
// CancellationManager::DeregisterCallback.
//
// Values must come from CancellationManager::get_cancellation_token(); do not
// fabricate them by hand.
using CancellationToken = int64;

// Callback that is run when a step is cancelled.
//
// NOTE(mrry): CancelCallback implementations must respect the caveats listed
// in the comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;
// Holds a set of cancellation callbacks, each keyed by a CancellationToken,
// and runs them when StartCancel() is called.
//
// Intended flow:
//   1. Obtain a token with get_cancellation_token().
//   2. Register a callback under that token with RegisterCallback().
//   3. Deregister it with DeregisterCallback() once the guarded operation
//      completes.
//
// Registered callbacks are stored in `callbacks_`, which is guarded by `mu_`.
// Instances own a mutex and registered callbacks, so they are neither
// copyable nor copy-assignable.
class CancellationManager {
 public:
  // A value that won't be returned by get_cancellation_token().
  static const CancellationToken kInvalidToken;

  CancellationManager();
  ~CancellationManager();

  // Copying a manager (its mutex, notification and registered callbacks)
  // has no meaningful semantics, so it is disallowed explicitly.
  CancellationManager(const CancellationManager&) = delete;
  CancellationManager& operator=(const CancellationManager&) = delete;

  // Run all callbacks associated with this manager.
  void StartCancel();

  // Returns true iff StartCancel() has been called.
  //
  // This is a single acquire-load of the atomic `is_cancelled_` flag. It does
  // not take `mu_` and does not modify the manager, hence `const`.
  bool IsCancelled() const {
    return is_cancelled_.load(std::memory_order_acquire);
  }

  // Returns a token that must be used in calls to RegisterCallback
  // and DeregisterCallback.
  CancellationToken get_cancellation_token();

  // Attempts to register the given callback to be invoked when this
  // manager is cancelled. Returns true if the callback was
  // registered; returns false if this manager was already cancelled,
  // and the callback was not registered.
  //
  // If this method returns false, it is the caller's responsibility
  // to perform any cancellation cleanup.
  //
  // This method is tricky to use correctly. The following usage pattern
  // is recommended:
  //
  // class ObjectWithCancellableOperation {
  //   mutex mu_;
  //   void CancellableOperation(CancellationManager* cm,
  //                             std::function<void(Status)> callback) {
  //     bool already_cancelled;
  //     CancellationToken token = cm->get_cancellation_token();
  //     {
  //       // Note: a named lock object is required; `mutex_lock(mu_);`
  //       // would declare a new variable called `mu_` instead of locking.
  //       mutex_lock lock(mu_);
  //       // RegisterCallback returns true on successful registration, so
  //       // the operation was already cancelled iff it returns false.
  //       already_cancelled = !cm->RegisterCallback(
  //           token, [this, token]() { Cancel(token); });
  //       if (!already_cancelled) {
  //         // Issue asynchronous operation. Associate the pending operation
  //         // with `token` in some object state, or provide another way for
  //         // the Cancel method to look up the operation for cancellation.
  //         // Ensure that `cm->DeregisterCallback(token)` is called without
  //         // holding `mu_`, before `callback` is invoked.
  //         // ...
  //       }
  //     }
  //     if (already_cancelled) {
  //       callback(errors::Cancelled("Operation was cancelled"));
  //     }
  //   }
  //
  //   void Cancel(CancellationToken token) {
  //     mutex_lock lock(mu_);
  //     // Take action to cancel the operation with the given cancellation
  //     // token.
  //   }
  // };
  //
  // NOTE(mrry): The caller should take care that (i) the calling code
  // is robust to `callback` being invoked asynchronously (e.g. from
  // another thread), (ii) `callback` is deregistered by a call to
  // this->DeregisterCallback(token) when the operation completes
  // successfully, and (iii) `callback` does not invoke any method
  // on this cancellation manager. Furthermore, it is important that
  // the eventual caller of the complementary DeregisterCallback does not
  // hold any mutexes that are required by `callback`.
  bool RegisterCallback(CancellationToken token, CancelCallback callback);

  // Deregister the callback that, when registered, was associated
  // with the given cancellation token. Returns true iff the callback
  // was deregistered and will not be invoked; otherwise returns false
  // after the callback has been invoked, blocking if necessary.
  //
  // NOTE(mrry): This method may block if cancellation is in progress.
  // The caller of this method must not hold any mutexes that are required
  // to invoke any cancellation callback that has been registered with this
  // cancellation manager.
  bool DeregisterCallback(CancellationToken token);

 private:
  // NOTE(review): presumably true while StartCancel() is running callbacks.
  // It carries no GUARDED_BY annotation; confirm in the .cc whether it is
  // only accessed under `mu_`.
  bool is_cancelling_;
  // Flag read lock-free by IsCancelled() (acquire load).
  std::atomic_bool is_cancelled_;
  // Protects next_cancellation_token_ and callbacks_.
  mutex mu_;
  // NOTE(review): presumably notified when cancellation finishes so that
  // DeregisterCallback() can block until callbacks have run; verify in .cc.
  Notification cancelled_notification_;
  // Source of values handed out by get_cancellation_token().
  CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
  // Callbacks currently registered, keyed by their cancellation token.
  std::unordered_map<CancellationToken, CancelCallback> callbacks_
      GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
|