#include "tensorflow/core/framework/cancellation.h"

// NOTE(review): the header name on this include line was lost when the file
// was extracted. <utility> is used because this file calls std::swap; confirm
// against the upstream source.
#include <utility>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Sentinel value that get_cancellation_token() never hands out (tokens start
// at 0 and only increase).
const CancellationToken CancellationManager::kInvalidToken = -1;

// Creates a manager that is neither cancelled nor in the middle of
// cancelling, and whose first issued token is 0.
CancellationManager::CancellationManager()
    : is_cancelling_(false),
      is_cancelled_(false),
      next_cancellation_token_(0) {}

// Runs every registered callback exactly once and then marks the manager as
// cancelled. A call that arrives after cancellation has completed, or while
// another cancellation is already in progress, returns immediately without
// doing anything.
void CancellationManager::StartCancel() {
  // Same type as callbacks_ by construction, so the swap below is always
  // well-formed regardless of how the member is declared.
  decltype(callbacks_) callbacks_to_run;
  {
    mutex_lock l(mu_);
    if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
      return;
    }
    is_cancelling_ = true;
    // Take ownership of the callbacks; callbacks_ is left empty.
    std::swap(callbacks_, callbacks_to_run);
  }
  // We call these callbacks without holding mu_, so that concurrent
  // calls to DeregisterCallback, which can happen asynchronously, do
  // not block. The callbacks remain valid because any concurrent call
  // to DeregisterCallback will block until the
  // cancelled_notification_ is notified.
  //
  // Iterate by reference: copying each (token, callback) pair would copy the
  // callback object for no benefit.
  for (auto& key_and_value : callbacks_to_run) {
    key_and_value.second();
  }
  {
    mutex_lock l(mu_);
    is_cancelling_ = false;
    is_cancelled_.store(true, std::memory_order_release);
  }
  // Wake any DeregisterCallback callers that were waiting for the callbacks
  // to finish.
  cancelled_notification_.Notify();
}

// Returns a fresh token; each call returns the previous value plus one.
CancellationToken CancellationManager::get_cancellation_token() {
  mutex_lock l(mu_);
  return next_cancellation_token_++;
}

// Associates `callback` with `token` unless cancellation has started or
// finished.
//
// `token` must have been obtained from get_cancellation_token(); otherwise
// the CHECK below fails.
//
// Returns true if the callback was stored, false if the manager is cancelled
// or currently cancelling (in which case `callback` is not stored).
bool CancellationManager::RegisterCallback(CancellationToken token,
                                           CancelCallback callback) {
  mutex_lock l(mu_);
  CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
  bool should_register = !is_cancelled_ && !is_cancelling_;
  if (should_register) {
    // Swap rather than assign: moves the new callback into the map and lets
    // any previous entry for this token be destroyed with the parameter.
    std::swap(callbacks_[token], callback);
  }
  return should_register;
}

// Removes the callback registered for `token`, if cancellation has not begun.
//
// Returns true if the manager was neither cancelled nor cancelling (the entry
// for `token`, if any, has been erased). Returns false if the manager is
// already cancelled, or after waiting for an in-progress cancellation to
// finish.
bool CancellationManager::DeregisterCallback(CancellationToken token) {
  {
    mutex_lock l(mu_);
    if (is_cancelled_) {
      return false;
    }
    if (!is_cancelling_) {
      callbacks_.erase(token);
      return true;
    }
    // Cancellation is in progress: fall through so that mu_ is released
    // before blocking below.
  }
  // Wait for all of the cancellation callbacks to be called. This
  // wait ensures that the caller of DeregisterCallback does not
  // return immediately and free objects that may be used in the
  // execution of any currently pending callbacks in StartCancel.
  cancelled_notification_.WaitForNotification();
  return false;
}

// Destroying the manager triggers cancellation, so any callbacks that are
// still registered are run via StartCancel().
CancellationManager::~CancellationManager() { StartCancel(); }

}  // end namespace tensorflow