aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/util/reffed_status_callback.h56
-rw-r--r--tensorflow/core/util/reffed_status_callback_test.cc111
3 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index a757a31de9..5502eebd7f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -445,6 +445,7 @@ tf_cuda_library(
"util/mirror_pad_mode.h",
"util/padding.h",
"util/port.h",
+ "util/reffed_status_callback.h",
"util/saved_tensor_slice_util.h",
"util/sparse/group_iterator.h",
"util/sparse/sparse_tensor.h",
@@ -2575,6 +2576,7 @@ tf_cc_tests(
"util/example_proto_helper_test.cc",
"util/memmapped_file_system_test.cc",
"util/presized_cuckoo_map_test.cc",
+ "util/reffed_status_callback_test.cc",
"util/reporter_test.cc",
"util/saved_tensor_slice_util_test.cc",
"util/semver_test.cc",
diff --git a/tensorflow/core/util/reffed_status_callback.h b/tensorflow/core/util/reffed_status_callback.h
new file mode 100644
index 0000000000..c31b42d1e6
--- /dev/null
+++ b/tensorflow/core/util/reffed_status_callback.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_
+#define TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_
+
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// The ReffedStatusCallback is a refcounted object that accepts a
+// StatusCallback. When it is destroyed (its refcount goes to 0), the
+// StatusCallback is called with the first non-OK status passed to
+// UpdateStatus(), or Status::OK() if no non-OK status was set.
+class ReffedStatusCallback : public core::RefCounted {
+ public:
+ explicit ReffedStatusCallback(StatusCallback done)
+ : done_(std::move(done)), status_(Status::OK()) {}
+
+ void UpdateStatus(const Status& s) {
+ if (!s.ok()) {
+ mutex_lock lock(mu_);
+ if (status_.ok()) status_.Update(s);
+ }
+ }
+
+ bool ok() {
+ mutex_lock lock(mu_);
+ return status_.ok();
+ }
+
+ ~ReffedStatusCallback() { done_(status_); }
+
+ private:
+ StatusCallback done_;
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_
diff --git a/tensorflow/core/util/reffed_status_callback_test.cc b/tensorflow/core/util/reffed_status_callback_test.cc
new file mode 100644
index 0000000000..7e776beb23
--- /dev/null
+++ b/tensorflow/core/util/reffed_status_callback_test.cc
@@ -0,0 +1,111 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <atomic>
+
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TestReffedStatusCallback, CallsBackOK) {
+ bool called = false;
+ Status status = errors::InvalidArgument("");
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ EXPECT_FALSE(called);
+ cb->Unref();
+ EXPECT_TRUE(called);
+ EXPECT_TRUE(status.ok());
+}
+
+TEST(TestReffedStatusCallback, CallsBackFail) {
+ bool called = false;
+ Status status = Status::OK();
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ cb->UpdateStatus(errors::Internal("1"));
+ cb->UpdateStatus(errors::Internal("2")); // Will be ignored.
+ EXPECT_FALSE(called);
+ cb->Unref();
+ EXPECT_TRUE(called);
+ EXPECT_EQ(status.error_message(), "1");
+}
+
+TEST(TestReffedStatusCallback, RefMulti) {
+ int called = false;
+ Status status = Status::OK();
+ auto done = [&called, &status](const Status& s) {
+ called = true;
+ status = s;
+ };
+ auto* cb = new ReffedStatusCallback(std::move(done));
+ cb->Ref();
+ cb->UpdateStatus(errors::Internal("1"));
+ cb->Ref();
+ cb->UpdateStatus(errors::Internal("2")); // Will be ignored.
+ cb->Unref();
+ cb->Unref();
+ EXPECT_FALSE(called);
+ cb->Unref(); // Created by constructor.
+ EXPECT_TRUE(called);
+ EXPECT_EQ(status.error_message(), "1");
+}
+
+TEST(TestReffedStatusCallback, MultiThreaded) {
+ std::atomic<int> num_called(0);
+ Status status;
+ Notification n;
+
+ auto done = [&num_called, &status, &n](const Status& s) {
+ ++num_called;
+ status = s;
+ n.Notify();
+ };
+
+ auto* cb = new ReffedStatusCallback(std::move(done));
+
+ thread::ThreadPool threads(Env::Default(), "test", 3);
+ for (int i = 0; i < 5; ++i) {
+ cb->Ref();
+ threads.Schedule([cb]() {
+ cb->UpdateStatus(errors::InvalidArgument("err"));
+ cb->Unref();
+ });
+ }
+
+ // Subtract one for the initial (construction) reference.
+ cb->Unref();
+
+ n.WaitForNotification();
+
+ EXPECT_EQ(num_called.load(), 1);
+ EXPECT_EQ(status.error_message(), "err");
+}
+
+} // namespace
+} // namespace tensorflow