diff options
author | 2018-04-24 13:13:18 -0700 | |
---|---|---|
committer | 2018-04-24 13:16:00 -0700 | |
commit | 893aa776009418c841d49c924207f3cdaf1d5174 (patch) | |
tree | a11ab86db407e72161d9b4734128604ae3492052 /tensorflow/core/util/rpc | |
parent | 33ffc8e7ff5090b92951c7faac150042dd814085 (diff) |
Fixing concurrency issues in RPC factory.
PiperOrigin-RevId: 194133903
Diffstat (limited to 'tensorflow/core/util/rpc')
-rw-r--r-- | tensorflow/core/util/rpc/call_container.h | 165 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory.h | 5 |
2 files changed, 128 insertions, 42 deletions
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h index 7f36056797..e1226a7f16 100644 --- a/tensorflow/core/util/rpc/call_container.h +++ b/tensorflow/core/util/rpc/call_container.h @@ -26,53 +26,60 @@ limitations under the License. namespace tensorflow { -template <typename Call> +namespace internal { +// The following class is used for coordination between a `CallContainer` +// instance and a cancellation callback to make sure that the `CallContainer` +// instance waits for the cancellation callback to be destroyed (either because +// a cancellation occurred or because the callback was deregistered) before +// deleting itself. Without this coordination the cancellation callback could +// attempt to access a `CallContainer` instance that is no longer valid. +class NotifyWhenDestroyed { + public: + explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification) + : notification_(std::move(notification)) {} + + ~NotifyWhenDestroyed() { notification_->Notify(); } + + private: + std::shared_ptr<Notification> notification_; +}; +} // namespace internal + +// The following class is responsible for the life cycle management of a set of +// RPC calls. The calls are started when an instance of the class is created and +// the class contract guarantees to invoke a "done" callback provided by the +// caller when all RPC calls have either completed or been cancelled. +// +// The caller should not make any assumptions about the validity of an instance +// of this class after the provided callback has been invoked, which may be +// immediately after the instance was created. +template <class Call> class CallContainer { public: + typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn; + typedef std::function<void(Call*)> StartCallFn; + + // Uses the provided `create_call_fn` and `start_call_fn` functions to create + // and start a set of RPC calls. When all RPC calls have either completed or + // been cancelled, the `done` callback is invoked. The caller should not make + // any assumptions about the validity of the created instance as the instance + // will delete itself after invoking the `done` callback. explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc, AsyncOpKernel::DoneCallback done, - CancellationToken token) - : ctx_(ctx), - done_(std::move(done)), - token_(token), - fail_fast_(fail_fast), - try_rpc_(try_rpc) { - CHECK_GT(num_calls, 0); - - // This will run when all RPCs are finished. - reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { - ctx_->cancellation_manager()->DeregisterCallback(token_); - ctx_->SetStatus(s); - done_(); - delete this; - }); - - // Subtract reference count from the initial creation. - core::ScopedUnref unref(reffed_status_callback_); - - for (int i = 0; i < num_calls; ++i) { - // Increase the reference on the callback for each new RPC. - reffed_status_callback_->Ref(); - } - } + CreateCallFn create_call_fn, + StartCallFn start_call_fn); - std::list<Call>* calls() { return &calls_; } + // Registers a call with this container. This method expects its arguments to + // match those of a `Call` constructor as it forwards them to an underlying + // collection, which creates a `Call` instance in place. + template <class... Args> + void RegisterCall(Args&&... args); - void StartCancel() { - // Once this loop is done, can no longer assume anything is valid - // because "delete this" may have been immediately called. - // Nothing should run after this loop. - for (auto& call : calls_) { - call.StartCancel(); - } - } + // Starts the cancellation of all RPC calls managed by this container. + void StartCancel(); - void Done(const Status& s, int index) { - if (!try_rpc_) { - reffed_status_callback_->UpdateStatus(s); - } - reffed_status_callback_->Unref(); - } + // Indicates that the `index`-th RPC call has finished. + void Done(const Status& s, int index); private: OpKernelContext* ctx_; @@ -81,10 +88,88 @@ class CallContainer { const CancellationToken token_; const bool fail_fast_; const bool try_rpc_; + std::shared_ptr<Notification> callback_destroyed_; // Performs its own reference counting. ReffedStatusCallback* reffed_status_callback_; }; +template <class Call> +CallContainer<Call>::CallContainer( + OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc, + AsyncOpKernel::DoneCallback done, + typename CallContainer<Call>::CreateCallFn create_call_fn, + typename CallContainer<Call>::StartCallFn start_call_fn) + : ctx_(ctx), + done_(std::move(done)), + token_(ctx->cancellation_manager()->get_cancellation_token()), + fail_fast_(fail_fast), + try_rpc_(try_rpc), + callback_destroyed_(new Notification) { + CHECK_GT(num_calls, 0); + + // This will run when all RPCs are finished. + reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { + ctx_->cancellation_manager()->DeregisterCallback(token_); + ctx_->SetStatus(s); + done_(); + callback_destroyed_->WaitForNotification(); + delete this; + }); + + // The cancellation callback needs to be registered before the RPC calls are + // started to make sure that the callback is properly cleaned up by the + // `reffed_status_callback` when all calls complete. At the same time, the + // cancellation callback should wait for the RPC calls to be started for the + // cancellation to take effect. + std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed( + new internal::NotifyWhenDestroyed(callback_destroyed_)); + std::shared_ptr<Notification> calls_started(new Notification); + bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback( + token_, [this, calls_started, notify_when_destroyed]() { + calls_started->WaitForNotification(); + StartCancel(); + }); + + for (int i = 0; i < num_calls; ++i) { + create_call_fn(this, i); + // Increase the reference on the callback for each new RPC. + reffed_status_callback_->Ref(); + } + for (Call& call : calls_) { + start_call_fn(&call); + } + calls_started->Notify(); + + if (is_cancelled) { + ctx_->SetStatus(errors::Cancelled("Operation has been cancelled.")); + StartCancel(); + } + + // Subtract reference count from the initial creation. + reffed_status_callback_->Unref(); +} + +template <class Call> +template <class... Args> +void CallContainer<Call>::RegisterCall(Args&&... args) { + calls_.emplace_back(std::forward<Args>(args)...); +} + +template <class Call> +void CallContainer<Call>::StartCancel() { + for (auto& call : calls_) { + call.StartCancel(); + } +} + +template <class Call> +void CallContainer<Call>::Done(const Status& s, int index) { + if (!try_rpc_) { + reffed_status_callback_->UpdateStatus(s); + } + reffed_status_callback_->Unref(); +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h index 9bf078c0f4..c4eaaf4457 100644 --- a/tensorflow/core/util/rpc/rpc_factory.h +++ b/tensorflow/core/util/rpc/rpc_factory.h @@ -32,10 +32,11 @@ class RPCFactory { RPCFactory() {} virtual ~RPCFactory() {} - // Start a Call() to methods `method_t` at addresses `address_t` with + // Asynchronously invokes methods `method_t` at addresses `address_t` with // request strings from `request_t`. Any of these may be scalar // Tensors, in which case the operands are broadcasted. - // Upon completion of all requests, `response_t` will be populated. + // Upon completion of all requests, `response_t` will be populated and the + // `done` callback will be invoked. // // If `try_rpc` is `true`, then `status_message_t` and // `status_code_t` will be populated as well. |