aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/cancellation.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/cancellation.h')
-rw-r--r--tensorflow/core/framework/cancellation.h121
1 files changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
new file mode 100644
index 0000000000..feda548e97
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.h
@@ -0,0 +1,121 @@
+#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+
+#include <atomic>
+#include <functional>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A token that can be used to register and deregister a
+// CancelCallback with a CancellationManager.
+//
+// CancellationToken values must be created by a call to
+// CancellationManager::get_cancellation_token.
+typedef int64 CancellationToken;
+
+// A callback that is invoked when a step is cancelled.
+//
+// NOTE(mrry): See caveats about CancelCallback implementations in the
+// comment for CancellationManager::RegisterCallback.
+typedef std::function<void()> CancelCallback;
+
+class CancellationManager {
+ public:
+ // A value that won't be returned by get_cancellation_token().
+ static const CancellationToken kInvalidToken;
+
+ CancellationManager();
+ ~CancellationManager();
+
+ // Run all callbacks associated with this manager.
+ void StartCancel();
+
+ // Returns true iff StartCancel() has been called.
+ bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
+
+ // Returns a token that must be used in calls to RegisterCallback
+ // and DeregisterCallback.
+ CancellationToken get_cancellation_token();
+
+ // Attempts to register the given callback to be invoked when this
+ // manager is cancelled. Returns true if the callback was
+ // registered; returns false if this manager was already cancelled,
+ // and the callback was not registered.
+ //
+ // If this method returns false, it is the caller's responsibility
+ // to perform any cancellation cleanup.
+ //
+ // This method is tricky to use correctly. The following usage pattern
+ // is recommended:
+ //
+ // class ObjectWithCancellableOperation {
+ // mutex mu_;
+ // void CancellableOperation(CancellationManager* cm,
+ // std::function<void(Status)> callback) {
+ // bool already_cancelled;
+ // CancellationToken token = cm->get_cancellation_token();
+ // {
+ // mutex_lock(mu_);
+ // already_cancelled = cm->RegisterCallback(
+ // [this, token]() { Cancel(token); });
+ // if (!already_cancelled) {
+ // // Issue asynchronous operation. Associate the pending operation
+ // // with `token` in some object state, or provide another way for
+ // // the Cancel method to look up the operation for cancellation.
+ // // Ensure that `cm->DeregisterCallback(token)` is called without
+ // // holding `mu_`, before `callback` is invoked.
+ // // ...
+ // }
+ // }
+ // if (already_cancelled) {
+ // callback(errors::Cancelled("Operation was cancelled"));
+ // }
+ // }
+ //
+ // void Cancel(CancellationToken token) {
+ // mutex_lock(mu_);
+ // // Take action to cancel the operation with the given cancellation
+ // // token.
+ // }
+ //
+ // NOTE(mrry): The caller should take care that (i) the calling code
+ // is robust to `callback` being invoked asynchronously (e.g. from
+ // another thread), (ii) `callback` is deregistered by a call to
+ // this->DeregisterCallback(token) when the operation completes
+ // successfully, and (iii) `callback` does not invoke any method
+ // on this cancellation manager. Furthermore, it is important that
+ // the eventual caller of the complementary DeregisterCallback does not
+ // hold any mutexes that are required by `callback`.
+ bool RegisterCallback(CancellationToken token, CancelCallback callback);
+
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // after the callback has been invoked, blocking if necessary.
+ //
+ // NOTE(mrry): This method may block if cancellation is in progress.
+ // The caller of this method must not hold any mutexes that are required
+ // to invoke any cancellation callback that has been registered with this
+ // cancellation manager.
+ bool DeregisterCallback(CancellationToken token);
+
+ private:
+ bool is_cancelling_;
+ std::atomic_bool is_cancelled_;
+
+ mutex mu_;
+ Notification cancelled_notification_;
+ CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_
+ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_