aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/reffed_status_callback_test.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-09-27 10:27:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-27 10:33:58 -0700
commit2ce49b2f6ad56b06ddc156c3b998ede6f4d1958e (patch)
treefe371b40a1f184a68d823bf53d5fd40ff388d691 /tensorflow/core/util/reffed_status_callback_test.cc
parent5cac28c41af785532e90101787cf85545cdac410 (diff)
Add new ReffedStatusCallback util class. This class allows multiple threads to
update a status before the underlying callback is executed. The use pattern is: auto cb = new ReffesStatusCallback(std::move(done)); auto execution = [cb](...) { if (cb->ok()) { cb->Ref(); ... } }; auto post_execution = [cb](const Status& s) { cb->SetStatus(s); cb->Unref(); } Status r = CallAsyncOp( ..., std::move(execution), std::move(post_execution) /*done*/); cb->SetStatus(r); cb->Unref(); PiperOrigin-RevId: 170216176
Diffstat (limited to 'tensorflow/core/util/reffed_status_callback_test.cc')
-rw-r--r--tensorflow/core/util/reffed_status_callback_test.cc111
1 files changed, 111 insertions, 0 deletions
diff --git a/tensorflow/core/util/reffed_status_callback_test.cc b/tensorflow/core/util/reffed_status_callback_test.cc
new file mode 100644
index 0000000000..7e776beb23
--- /dev/null
+++ b/tensorflow/core/util/reffed_status_callback_test.cc
@@ -0,0 +1,111 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <atomic>
+
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TestReffedStatusCallback, CallsBackOK) {
+ bool called = false;
+ Status status = errors::InvalidArgument("");
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ EXPECT_FALSE(called);
+ cb->Unref();
+ EXPECT_TRUE(called);
+ EXPECT_TRUE(status.ok());
+}
+
+TEST(TestReffedStatusCallback, CallsBackFail) {
+ bool called = false;
+ Status status = Status::OK();
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ cb->UpdateStatus(errors::Internal("1"));
+ cb->UpdateStatus(errors::Internal("2")); // Will be ignored.
+ EXPECT_FALSE(called);
+ cb->Unref();
+ EXPECT_TRUE(called);
+ EXPECT_EQ(status.error_message(), "1");
+}
+
+TEST(TestReffedStatusCallback, RefMulti) {
+ int called = false;
+ Status status = Status::OK();
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ cb->Ref();
+ cb->UpdateStatus(errors::Internal("1"));
+ cb->Ref();
+ cb->UpdateStatus(errors::Internal("2")); // Will be ignored.
+ cb->Unref();
+ cb->Unref();
+ EXPECT_FALSE(called);
+ cb->Unref(); // Created by constructor.
+ EXPECT_TRUE(called);
+ EXPECT_EQ(status.error_message(), "1");
+}
+
+TEST(TestReffedStatusCallback, MultiThreaded) {
+ std::atomic<int> num_called(0);
+ Status status;
+ Notification n;
+
+ auto done = [&num_called, &status, &n](const Status& s) {
+ ++num_called;
+ status = s;
+ n.Notify();
+ };
+
+ auto* cb = new ReffedStatusCallback(std::move(done));
+
+ thread::ThreadPool threads(Env::Default(), "test", 3);
+ for (int i = 0; i < 5; ++i) {
+ cb->Ref();
+ threads.Schedule([cb]() {
+ cb->UpdateStatus(errors::InvalidArgument("err"));
+ cb->Unref();
+ });
+ }
+
+ // Subtract one for the initial (construction) reference.
+ cb->Unref();
+
+ n.WaitForNotification();
+
+ EXPECT_EQ(num_called.load(), 1);
+ EXPECT_EQ(status.error_message(), "err");
+}
+
+} // namespace
+} // namespace tensorflow