diff options
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/util/reffed_status_callback.h | 56 | ||||
-rw-r--r-- | tensorflow/core/util/reffed_status_callback_test.cc | 111 |
3 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a757a31de9..5502eebd7f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -445,6 +445,7 @@ tf_cuda_library( "util/mirror_pad_mode.h", "util/padding.h", "util/port.h", + "util/reffed_status_callback.h", "util/saved_tensor_slice_util.h", "util/sparse/group_iterator.h", "util/sparse/sparse_tensor.h", @@ -2575,6 +2576,7 @@ tf_cc_tests( "util/example_proto_helper_test.cc", "util/memmapped_file_system_test.cc", "util/presized_cuckoo_map_test.cc", + "util/reffed_status_callback_test.cc", "util/reporter_test.cc", "util/saved_tensor_slice_util_test.cc", "util/semver_test.cc", diff --git a/tensorflow/core/util/reffed_status_callback.h b/tensorflow/core/util/reffed_status_callback.h new file mode 100644 index 0000000000..c31b42d1e6 --- /dev/null +++ b/tensorflow/core/util/reffed_status_callback.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_ +#define TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_ + +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The ReffedStatusCallback is a refcounted object that accepts a +// StatusCallback. When it is destroyed (its refcount goes to 0), the +// StatusCallback is called with the first non-OK status passed to +// UpdateStatus(), or Status::OK() if no non-OK status was set. +class ReffedStatusCallback : public core::RefCounted { + public: + explicit ReffedStatusCallback(StatusCallback done) + : done_(std::move(done)), status_(Status::OK()) {} + + void UpdateStatus(const Status& s) { + if (!s.ok()) { + mutex_lock lock(mu_); + if (status_.ok()) status_.Update(s); + } + } + + bool ok() { + mutex_lock lock(mu_); + return status_.ok(); + } + + ~ReffedStatusCallback() { done_(status_); } + + private: + StatusCallback done_; + mutex mu_; + Status status_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_ diff --git a/tensorflow/core/util/reffed_status_callback_test.cc b/tensorflow/core/util/reffed_status_callback_test.cc new file mode 100644 index 0000000000..7e776beb23 --- /dev/null +++ b/tensorflow/core/util/reffed_status_callback_test.cc @@ -0,0 +1,111 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <atomic> + +#include "tensorflow/core/util/reffed_status_callback.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(TestReffedStatusCallback, CallsBackOK) { + bool called = false; + Status status = errors::InvalidArgument(""); + auto done = [&called, &status](const Status& s) { + called = true; + status = s; + }; + auto* cb = new ReffedStatusCallback(std::move(done)); + EXPECT_FALSE(called); + cb->Unref(); + EXPECT_TRUE(called); + EXPECT_TRUE(status.ok()); +} + +TEST(TestReffedStatusCallback, CallsBackFail) { + bool called = false; + Status status = Status::OK(); + auto done = [&called, &status](const Status& s) { + called = true; + status = s; + }; + auto* cb = new ReffedStatusCallback(std::move(done)); + cb->UpdateStatus(errors::Internal("1")); + cb->UpdateStatus(errors::Internal("2")); // Will be ignored. + EXPECT_FALSE(called); + cb->Unref(); + EXPECT_TRUE(called); + EXPECT_EQ(status.error_message(), "1"); +} + +TEST(TestReffedStatusCallback, RefMulti) { + int called = false; + Status status = Status::OK(); + auto done = [&called, &status](const Status& s) { + called = true; + status = s; + }; + auto* cb = new ReffedStatusCallback(std::move(done)); + cb->Ref(); + cb->UpdateStatus(errors::Internal("1")); + cb->Ref(); + cb->UpdateStatus(errors::Internal("2")); // Will be ignored. + cb->Unref(); + cb->Unref(); + EXPECT_FALSE(called); + cb->Unref(); // Created by constructor. + EXPECT_TRUE(called); + EXPECT_EQ(status.error_message(), "1"); +} + +TEST(TestReffedStatusCallback, MultiThreaded) { + std::atomic<int> num_called(0); + Status status; + Notification n; + + auto done = [&num_called, &status, &n](const Status& s) { + ++num_called; + status = s; + n.Notify(); + }; + + auto* cb = new ReffedStatusCallback(std::move(done)); + + thread::ThreadPool threads(Env::Default(), "test", 3); + for (int i = 0; i < 5; ++i) { + cb->Ref(); + threads.Schedule([cb]() { + cb->UpdateStatus(errors::InvalidArgument("err")); + cb->Unref(); + }); + } + + // Subtract one for the initial (construction) reference. + cb->Unref(); + + n.WaitForNotification(); + + EXPECT_EQ(num_called.load(), 1); + EXPECT_EQ(status.error_message(), "err"); +} + +} // namespace +} // namespace tensorflow |