diff options
Diffstat (limited to 'tensorflow/core/framework/cancellation.h')
-rw-r--r-- | tensorflow/core/framework/cancellation.h | 121 |
1 files changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h new file mode 100644 index 0000000000..feda548e97 --- /dev/null +++ b/tensorflow/core/framework/cancellation.h @@ -0,0 +1,121 @@ +#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_ +#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_ + +#include <atomic> +#include <functional> +#include <unordered_map> + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A token that can be used to register and deregister a +// CancelCallback with a CancellationManager. +// +// CancellationToken values must be created by a call to +// CancellationManager::get_cancellation_token. +typedef int64 CancellationToken; + +// A callback that is invoked when a step is cancelled. +// +// NOTE(mrry): See caveats about CancelCallback implementations in the +// comment for CancellationManager::RegisterCallback. +typedef std::function<void()> CancelCallback; + +class CancellationManager { + public: + // A value that won't be returned by get_cancellation_token(). + static const CancellationToken kInvalidToken; + + CancellationManager(); + ~CancellationManager(); + + // Run all callbacks associated with this manager. + void StartCancel(); + + // Returns true iff StartCancel() has been called. + bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } + + // Returns a token that must be used in calls to RegisterCallback + // and DeregisterCallback. + CancellationToken get_cancellation_token(); + + // Attempts to register the given callback to be invoked when this + // manager is cancelled. Returns true if the callback was + // registered; returns false if this manager was already cancelled, + // and the callback was not registered. + // + // If this method returns false, it is the caller's responsibility + // to perform any cancellation cleanup. + // + // This method is tricky to use correctly. The following usage pattern + // is recommended: + // + // class ObjectWithCancellableOperation { + // mutex mu_; + // void CancellableOperation(CancellationManager* cm, + // std::function<void(Status)> callback) { + // bool already_cancelled; + // CancellationToken token = cm->get_cancellation_token(); + // { + // mutex_lock(mu_); + // already_cancelled = cm->RegisterCallback( + // [this, token]() { Cancel(token); }); + // if (!already_cancelled) { + // // Issue asynchronous operation. Associate the pending operation + // // with `token` in some object state, or provide another way for + // // the Cancel method to look up the operation for cancellation. + // // Ensure that `cm->DeregisterCallback(token)` is called without + // // holding `mu_`, before `callback` is invoked. + // // ... + // } + // } + // if (already_cancelled) { + // callback(errors::Cancelled("Operation was cancelled")); + // } + // } + // + // void Cancel(CancellationToken token) { + // mutex_lock(mu_); + // // Take action to cancel the operation with the given cancellation + // // token. + // } + // + // NOTE(mrry): The caller should take care that (i) the calling code + // is robust to `callback` being invoked asynchronously (e.g. from + // another thread), (ii) `callback` is deregistered by a call to + // this->DeregisterCallback(token) when the operation completes + // successfully, and (iii) `callback` does not invoke any method + // on this cancellation manager. Furthermore, it is important that + // the eventual caller of the complementary DeregisterCallback does not + // hold any mutexes that are required by `callback`. + bool RegisterCallback(CancellationToken token, CancelCallback callback); + + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // after the callback has been invoked, blocking if necessary. + // + // NOTE(mrry): This method may block if cancellation is in progress. + // The caller of this method must not hold any mutexes that are required + // to invoke any cancellation callback that has been registered with this + // cancellation manager. + bool DeregisterCallback(CancellationToken token); + + private: + bool is_cancelling_; + std::atomic_bool is_cancelled_; + + mutex mu_; + Notification cancelled_notification_; + CancellationToken next_cancellation_token_ GUARDED_BY(mu_); + std::unordered_map<CancellationToken, CancelCallback> callbacks_ + GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_ |