diff options
Diffstat (limited to 'tensorflow/core/framework/cancellation.cc')
-rw-r--r-- | tensorflow/core/framework/cancellation.cc | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc new file mode 100644 index 0000000000..51423792a8 --- /dev/null +++ b/tensorflow/core/framework/cancellation.cc @@ -0,0 +1,79 @@ +#include "tensorflow/core/framework/cancellation.h" + +#include <vector> + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +const CancellationToken CancellationManager::kInvalidToken = -1; + +CancellationManager::CancellationManager() + : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {} + +void CancellationManager::StartCancel() { + std::unordered_map<CancellationToken, CancelCallback> callbacks_to_run; + { + mutex_lock l(mu_); + if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { + return; + } + is_cancelling_ = true; + std::swap(callbacks_, callbacks_to_run); + } + // We call these callbacks without holding mu_, so that concurrent + // calls to DeregisterCallback, which can happen asynchronously, do + // not block. The callbacks remain valid because any concurrent call + // to DeregisterCallback will block until the + // cancelled_notification_ is notified. + for (auto key_and_value : callbacks_to_run) { + key_and_value.second(); + } + { + mutex_lock l(mu_); + is_cancelling_ = false; + is_cancelled_.store(true, std::memory_order_release); + } + cancelled_notification_.Notify(); +} + +CancellationToken CancellationManager::get_cancellation_token() { + mutex_lock l(mu_); + return next_cancellation_token_++; +} + +bool CancellationManager::RegisterCallback(CancellationToken token, + CancelCallback callback) { + mutex_lock l(mu_); + CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token"; + bool should_register = !is_cancelled_ && !is_cancelling_; + if (should_register) { + std::swap(callbacks_[token], callback); + } + return should_register; +} + +bool CancellationManager::DeregisterCallback(CancellationToken token) { + mu_.lock(); + if (is_cancelled_) { + mu_.unlock(); + return false; + } else if (is_cancelling_) { + mu_.unlock(); + // Wait for all of the cancellation callbacks to be called. This + // wait ensures that the caller of DeregisterCallback does not + // return immediately and free objects that may be used in the + // execution of any currently pending callbacks in StartCancel. + cancelled_notification_.WaitForNotification(); + return false; + } else { + callbacks_.erase(token); + mu_.unlock(); + return true; + } +} + +CancellationManager::~CancellationManager() { StartCancel(); } + +} // end namespace tensorflow |