diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-20 01:43:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 01:46:55 -0700 |
commit | a54310b1faa39df94dcef9ad1b5aaa0acc691e35 (patch) | |
tree | b8f1a60490cd697e008b89569f775dd5aede5799 /tensorflow/core/framework | |
parent | da3357ecbdd6772413e8bbceeab8238971be11ce (diff) |
Internal change.
PiperOrigin-RevId: 213770000
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/cancellation.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/framework/cancellation.h | 9 | ||||
-rw-r--r-- | tensorflow/core/framework/cancellation_test.cc | 52 |
3 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index 1258e40c93..af59500aee 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) { } } +bool CancellationManager::TryDeregisterCallback(CancellationToken token) { + mutex_lock lock(mu_); + if (is_cancelled_ || is_cancelling_) { + return false; + } else { + callbacks_.erase(token); + return true; + } +} + CancellationManager::~CancellationManager() { if (!callbacks_.empty()) { StartCancel(); diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h index acdaaf6a90..7a5d942486 100644 --- a/tensorflow/core/framework/cancellation.h +++ b/tensorflow/core/framework/cancellation.h @@ -122,6 +122,15 @@ class CancellationManager { // cancellation manager. bool DeregisterCallback(CancellationToken token); + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // immediately, with no guarantee that the callback has completed. + // + // This method is guaranteed to return true if StartCancel has not been + // called. + bool TryDeregisterCallback(CancellationToken token); + private: bool is_cancelling_; std::atomic_bool is_cancelled_; diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc index e3f18240b5..bf7593bc5f 100644 --- a/tensorflow/core/framework/cancellation_test.cc +++ b/tensorflow/core/framework/cancellation_test.cc @@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) { delete cm; } +TEST(Cancellation, TryDeregisterWithoutCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + bool deregistered = manager->TryDeregisterCallback(token); + EXPECT_TRUE(deregistered); + delete manager; + EXPECT_FALSE(is_cancelled); +} + +TEST(Cancellation, TryDeregisterAfterCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + bool deregistered = manager->TryDeregisterCallback(token); + EXPECT_FALSE(deregistered); + delete manager; +} + +TEST(Cancellation, TryDeregisterDuringCancel) { + Notification cancel_started, finish_callback, cancel_complete; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback(token, [&]() { + cancel_started.Notify(); + finish_callback.WaitForNotification(); + }); + EXPECT_TRUE(registered); + + thread::ThreadPool w(Env::Default(), "test", 1); + w.Schedule([&]() { + manager->StartCancel(); + cancel_complete.Notify(); + }); + cancel_started.WaitForNotification(); + + bool deregistered = manager->TryDeregisterCallback(token); + EXPECT_FALSE(deregistered); + + finish_callback.Notify(); + cancel_complete.WaitForNotification(); + delete manager; +} + } // namespace tensorflow |