about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/util/rpc
diff options
context:
space:
mode:
author: Jiri Simsa <jsimsa@google.com> 2018-04-24 13:13:18 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-24 13:16:00 -0700
commit: 893aa776009418c841d49c924207f3cdaf1d5174 (patch)
tree: a11ab86db407e72161d9b4734128604ae3492052 /tensorflow/core/util/rpc
parent: 33ffc8e7ff5090b92951c7faac150042dd814085 (diff)
Fixing concurrency issues in RPC factory.
PiperOrigin-RevId: 194133903
Diffstat (limited to 'tensorflow/core/util/rpc')
-rw-r--r-- tensorflow/core/util/rpc/call_container.h 165
-rw-r--r-- tensorflow/core/util/rpc/rpc_factory.h 5
2 files changed, 128 insertions, 42 deletions
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
index 7f36056797..e1226a7f16 100644
--- a/tensorflow/core/util/rpc/call_container.h
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -26,53 +26,60 @@ limitations under the License.
namespace tensorflow {
-template <typename Call>
+namespace internal {
+// The following class is used for coordination between a `CallContainer`
+// instance and a cancellation callback to make sure that the `CallContainer`
+// instance waits for the cancellation callback to be destroyed (either because
+// a cancellation occurred or because the callback was deregistered) before
+// deleting itself. Without this coordination the cancellation callback could
+// attempt to access a `CallContainer` instance that is no longer valid.
+//
+// Mechanically, the class holds a shared `Notification` and calls `Notify()` on
+// it from its destructor; whoever waits on that notification is released once
+// this object has been destroyed.
+class NotifyWhenDestroyed {
+ public:
+ // Takes shared ownership of `notification`, which is signalled when this
+ // object is destroyed.
+ explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
+ : notification_(std::move(notification)) {}
+
+ // Signals the held notification. `notification_` is dereferenced
+ // unconditionally, so the constructor argument must be non-null.
+ // NOTE(review): no copy operations are declared, so the implicit ones apply;
+ // destroying a copy would call `Notify()` on the same notification a second
+ // time -- confirm that instances are only ever held through a pointer.
+ ~NotifyWhenDestroyed() { notification_->Notify(); }
+
+ private:
+ // Notification signalled by the destructor.
+ std::shared_ptr<Notification> notification_;
+};
+} // namespace internal
+
+// The following class is responsible for the life cycle management of a set of
+// RPC calls. The calls are started when an instance of the class is created and
+// the class contract guarantees to invoke a "done" callback provided by the
+// caller when all RPC calls have either completed or been cancelled.
+//
+// The caller should not make any assumptions about the validity of an instance
+// of this class after the provided callback has been invoked, which may be
+// immediately after the instance was created.
+template <class Call>
class CallContainer {
public:
+ typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
+ typedef std::function<void(Call*)> StartCallFn;
+
+ // Uses the provided `create_call_fn` and `start_call_fn` functions to create
+ // and start a set of RPC calls. When all RPC calls have either completed or
+ // been cancelled, the `done` callback is invoked. The caller should not make
+ // any assumptions about the validity of the created instance as the instance
+ // will delete itself after invoking the `done` callback.
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
bool try_rpc, AsyncOpKernel::DoneCallback done,
- CancellationToken token)
- : ctx_(ctx),
- done_(std::move(done)),
- token_(token),
- fail_fast_(fail_fast),
- try_rpc_(try_rpc) {
- CHECK_GT(num_calls, 0);
-
- // This will run when all RPCs are finished.
- reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
- ctx_->cancellation_manager()->DeregisterCallback(token_);
- ctx_->SetStatus(s);
- done_();
- delete this;
- });
-
- // Subtract reference count from the initial creation.
- core::ScopedUnref unref(reffed_status_callback_);
-
- for (int i = 0; i < num_calls; ++i) {
- // Increase the reference on the callback for each new RPC.
- reffed_status_callback_->Ref();
- }
- }
+ CreateCallFn create_call_fn,
+ StartCallFn start_call_fn);
- std::list<Call>* calls() { return &calls_; }
+ // Registers a call with this container. This method expects its arguments to
+ // match those of a `Call` constructor as it forwards them to an underlying
+ // collection, which creates a `Call` instance in place.
+ template <class... Args>
+ void RegisterCall(Args&&... args);
- void StartCancel() {
- // Once this loop is done, can no longer assume anything is valid
- // because "delete this" may have been immediately called.
- // Nothing should run after this loop.
- for (auto& call : calls_) {
- call.StartCancel();
- }
- }
+ // Starts the cancellation of all RPC calls managed by this container.
+ void StartCancel();
- void Done(const Status& s, int index) {
- if (!try_rpc_) {
- reffed_status_callback_->UpdateStatus(s);
- }
- reffed_status_callback_->Unref();
- }
+ // Indicates that the `index`-th RPC call has finished.
+ void Done(const Status& s, int index);
private:
OpKernelContext* ctx_;
@@ -81,10 +88,88 @@ class CallContainer {
const CancellationToken token_;
const bool fail_fast_;
const bool try_rpc_;
+ std::shared_ptr<Notification> callback_destroyed_;
// Performs its own reference counting.
ReffedStatusCallback* reffed_status_callback_;
};
+// Creates `num_calls` calls through `create_call_fn`, starts every call held in
+// `calls_` through `start_call_fn`, and arranges for `done` to be invoked once
+// the status callback fires. The instance deletes itself after invoking `done`,
+// which may happen before this constructor returns.
+//
+// Arguments:
+//   ctx: supplies the cancellation manager and receives the final status.
+//   num_calls: number of calls to create; CHECK-fails unless it is > 0.
+//   fail_fast: stored in `fail_fast_`.
+//   try_rpc: stored in `try_rpc_`; when true, `Done` does not forward per-call
+//     statuses to the status callback.
+//   done: invoked after the final status has been set on `ctx`.
+//   create_call_fn: invoked as `create_call_fn(this, i)` for each `i` in
+//     `[0, num_calls)`. NOTE(review): presumably it registers the call with
+//     this container via `RegisterCall` -- confirm against callers.
+//   start_call_fn: invoked once for each element of `calls_`.
+template <class Call>
+CallContainer<Call>::CallContainer(
+    OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
+    AsyncOpKernel::DoneCallback done,
+    typename CallContainer<Call>::CreateCallFn create_call_fn,
+    typename CallContainer<Call>::StartCallFn start_call_fn)
+    : ctx_(ctx),
+      done_(std::move(done)),
+      token_(ctx->cancellation_manager()->get_cancellation_token()),
+      fail_fast_(fail_fast),
+      try_rpc_(try_rpc),
+      callback_destroyed_(new Notification) {
+  CHECK_GT(num_calls, 0);
+
+  // This will run when all RPCs are finished. It deregisters the cancellation
+  // callback, publishes the status, invokes `done_`, and then waits until the
+  // cancellation callback object has been destroyed before deleting `this`, so
+  // that the callback can never observe a dead container.
+  reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+    ctx_->cancellation_manager()->DeregisterCallback(token_);
+    ctx_->SetStatus(s);
+    done_();
+    callback_destroyed_->WaitForNotification();
+    delete this;
+  });
+
+  // The locals below are confined to their own scope so that this function's
+  // copy of `notify_when_destroyed` is released before the final `Unref()`.
+  // If that `Unref()` drops the last reference, the status callback above runs
+  // on this thread and blocks in `callback_destroyed_->WaitForNotification()`;
+  // a copy still alive in this frame would keep the notification from ever
+  // firing, and the constructor would deadlock.
+  {
+    // The cancellation callback needs to be registered before the RPC calls
+    // are started to make sure that the callback is properly cleaned up by the
+    // `reffed_status_callback` when all calls complete. At the same time, the
+    // cancellation callback should wait for the RPC calls to be started for
+    // the cancellation to take effect.
+    std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
+        new internal::NotifyWhenDestroyed(callback_destroyed_));
+    std::shared_ptr<Notification> calls_started(new Notification);
+    bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
+        token_, [this, calls_started, notify_when_destroyed]() {
+          calls_started->WaitForNotification();
+          StartCancel();
+        });
+
+    for (int i = 0; i < num_calls; ++i) {
+      create_call_fn(this, i);
+      // Increase the reference on the callback for each new RPC.
+      reffed_status_callback_->Ref();
+    }
+    for (Call& call : calls_) {
+      start_call_fn(&call);
+    }
+    calls_started->Notify();
+
+    // Registration was refused, so the registered callback will not run;
+    // cancel the freshly started calls directly instead.
+    if (is_cancelled) {
+      ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
+      StartCancel();
+    }
+  }
+
+  // Subtract reference count from the initial creation. Nothing may touch
+  // `this` after this line: the status callback may already have deleted it.
+  reffed_status_callback_->Unref();
+}
+
+// Appends a new `Call` to `calls_`, constructing it in place from `args`, which
+// are perfectly forwarded to `emplace_back`.
+template <class Call>
+template <class... Args>
+void CallContainer<Call>::RegisterCall(Args&&... args) {
+ calls_.emplace_back(std::forward<Args>(args)...);
+}
+
+// Calls `StartCancel()` on every call currently held in `calls_`.
+template <class Call>
+void CallContainer<Call>::StartCancel() {
+ for (auto& call : calls_) {
+ call.StartCancel();
+ }
+}
+
+// Records that one RPC call has finished with status `s`. Unless `try_rpc_` is
+// set, `s` is forwarded to `reffed_status_callback_->UpdateStatus`. In every
+// case one reference on the status callback is released. The `index` argument
+// is not used by this body.
+// NOTE(review): releasing the last reference presumably runs the completion
+// callback, which deletes this container -- callers should not touch the
+// instance after this returns.
+template <class Call>
+void CallContainer<Call>::Done(const Status& s, int index) {
+ if (!try_rpc_) {
+ reffed_status_callback_->UpdateStatus(s);
+ }
+ reffed_status_callback_->Unref();
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
index 9bf078c0f4..c4eaaf4457 100644
--- a/tensorflow/core/util/rpc/rpc_factory.h
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -32,10 +32,11 @@ class RPCFactory {
RPCFactory() {}
virtual ~RPCFactory() {}
- // Start a Call() to methods `method_t` at addresses `address_t` with
+ // Asynchronously invokes methods `method_t` at addresses `address_t` with
// request strings from `request_t`. Any of these may be scalar
// Tensors, in which case the operands are broadcasted.
- // Upon completion of all requests, `response_t` will be populated.
+ // Upon completion of all requests, `response_t` will be populated and the
+ // `done` callback will be invoked.
//
// If `try_rpc` is `true`, then `status_message_t` and
// `status_code_t` will be populated as well.