aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/cancellation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/cancellation.cc')
-rw-r--r--tensorflow/core/framework/cancellation.cc79
1 files changed, 79 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
new file mode 100644
index 0000000000..51423792a8
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.cc
@@ -0,0 +1,79 @@
+#include "tensorflow/core/framework/cancellation.h"
+
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+const CancellationToken CancellationManager::kInvalidToken = -1;
+
+CancellationManager::CancellationManager()
+ : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {}
+
+void CancellationManager::StartCancel() {
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_to_run;
+ {
+ mutex_lock l(mu_);
+ if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
+ return;
+ }
+ is_cancelling_ = true;
+ std::swap(callbacks_, callbacks_to_run);
+ }
+ // We call these callbacks without holding mu_, so that concurrent
+ // calls to DeregisterCallback, which can happen asynchronously, do
+ // not block. The callbacks remain valid because any concurrent call
+ // to DeregisterCallback will block until the
+ // cancelled_notification_ is notified.
+ for (auto key_and_value : callbacks_to_run) {
+ key_and_value.second();
+ }
+ {
+ mutex_lock l(mu_);
+ is_cancelling_ = false;
+ is_cancelled_.store(true, std::memory_order_release);
+ }
+ cancelled_notification_.Notify();
+}
+
+CancellationToken CancellationManager::get_cancellation_token() {
+ mutex_lock l(mu_);
+ return next_cancellation_token_++;
+}
+
+bool CancellationManager::RegisterCallback(CancellationToken token,
+ CancelCallback callback) {
+ mutex_lock l(mu_);
+ CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
+ bool should_register = !is_cancelled_ && !is_cancelling_;
+ if (should_register) {
+ std::swap(callbacks_[token], callback);
+ }
+ return should_register;
+}
+
+bool CancellationManager::DeregisterCallback(CancellationToken token) {
+ mu_.lock();
+ if (is_cancelled_) {
+ mu_.unlock();
+ return false;
+ } else if (is_cancelling_) {
+ mu_.unlock();
+ // Wait for all of the cancellation callbacks to be called. This
+ // wait ensures that the caller of DeregisterCallback does not
+ // return immediately and free objects that may be used in the
+ // execution of any currently pending callbacks in StartCancel.
+ cancelled_notification_.WaitForNotification();
+ return false;
+ } else {
+ callbacks_.erase(token);
+ mu_.unlock();
+ return true;
+ }
+}
+
+CancellationManager::~CancellationManager() { StartCancel(); }
+
+} // end namespace tensorflow