aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-20 01:43:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 01:46:55 -0700
commita54310b1faa39df94dcef9ad1b5aaa0acc691e35 (patch)
treeb8f1a60490cd697e008b89569f775dd5aede5799 /tensorflow/core/framework
parentda3357ecbdd6772413e8bbceeab8238971be11ce (diff)
Internal change.
PiperOrigin-RevId: 213770000
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/cancellation.cc10
-rw-r--r--tensorflow/core/framework/cancellation.h9
-rw-r--r--tensorflow/core/framework/cancellation_test.cc52
3 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c93..af59500aee 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6a90..7a5d942486 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@ class CancellationManager {
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b5..bf7593bc5f 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow