diff options
7 files changed, 251 insertions, 134 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index f3e6731213..2311c15a68 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -28,7 +28,6 @@ py_library( py_library( name = "rpc_op_test_base", srcs = ["rpc_op_test_base.py"], - tags = ["notsan"], deps = [ ":test_example_proto_py", "//tensorflow/contrib/proto", diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index e2e0dbc7a2..3fc6bfbb4d 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): _protocol = 'grpc' invalid_method_string = 'Method not found' + connect_failed_string = 'Connect Failed' def __init__(self, methodName='runTest'): # pylint: disable=invalid-name super(RpcOpTest, self).__init__(methodName) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 89f3ee1a1c..27273d16b1 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -93,40 +93,39 @@ class RpcOpTestBase(object): response_values = sess.run(response_tensors) self.assertAllEqual(response_values.shape, [0]) - def testInvalidAddresses(self): - with self.test_session() as sess: - with self.assertRaisesOpError(self.invalid_method_string): - sess.run( - self.rpc( - method='/InvalidService.IncrementTestShapes', - address=self._address, - request='')) + def testInvalidMethod(self): + for method in [ + '/InvalidService.IncrementTestShapes', + self.get_method_name('InvalidMethodName') + ]: + with self.test_session() as sess: + with self.assertRaisesOpError(self.invalid_method_string): + sess.run(self.rpc(method=method, address=self._address, request='')) - with self.assertRaisesOpError(self.invalid_method_string): - sess.run( - self.rpc( - method=self.get_method_name('InvalidMethodName'), - address=self._address, - request='')) + _, status_code_value, status_message_value = sess.run( + self.try_rpc(method=method, address=self._address, request='')) + self.assertEqual(errors.UNIMPLEMENTED, status_code_value) + self.assertTrue( + self.invalid_method_string in status_message_value.decode('ascii')) - # This also covers the case of address='' - # and address='localhost:293874293874' + def testInvalidAddress(self): + # This covers the case of address='' and address='localhost:293874293874' + address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' + with self.test_session() as sess: with self.assertRaises(errors.UnavailableError): sess.run( self.rpc( method=self.get_method_name('IncrementTestShapes'), - address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@', + address=address, request='')) - - # Test invalid method with the TryRpc op _, status_code_value, status_message_value = sess.run( self.try_rpc( - method=self.get_method_name('InvalidMethodName'), - address=self._address, + method=self.get_method_name('IncrementTestShapes'), + address=address, request='')) - self.assertEqual(errors.UNIMPLEMENTED, status_code_value) + self.assertEqual(errors.UNAVAILABLE, status_code_value) self.assertTrue( - self.invalid_method_string in status_message_value.decode('ascii')) + self.connect_failed_string in status_message_value.decode('ascii')) def testAlwaysFailingMethod(self): with self.test_session() as sess: @@ -138,6 +137,18 @@ class RpcOpTestBase(object): with self.assertRaisesOpError(I_WARNED_YOU): sess.run(response_tensors) + response_tensors, status_code, status_message = self.try_rpc( + method=self.get_method_name('AlwaysFailWithInvalidArgument'), + address=self._address, + request='') + self.assertEqual(response_tensors.shape, ()) + self.assertEqual(status_code.shape, ()) + self.assertEqual(status_message.shape, ()) + status_code_value, status_message_value = sess.run((status_code, + status_message)) + self.assertEqual(errors.INVALID_ARGUMENT, status_code_value) + self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii')) + def testSometimesFailingMethodWithManyRequests(self): with self.test_session() as sess: # Fail hard by default. @@ -197,8 +208,7 @@ class RpcOpTestBase(object): address=self._address, request=request_tensors) for _ in range(10) ] - # Launch parallel 10 calls to the RpcOp, each containing - # 20 rpc requests. + # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests. many_response_values = sess.run(many_response_tensors) self.assertEqual(10, len(many_response_values)) for response_values in many_response_values: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc index d004abd1c1..cde6b785dc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { -namespace { +namespace internal { class GrpcCall { public: explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc, @@ -57,9 +57,10 @@ class GrpcCall { container_->Done(s, index_); } + CallOptions* call_opts() { return &call_opts_; } + int index() { return index_; } const string& request() const { return *request_msg_; } string* response() const { return response_msg_; } - CallOptions* call_opts() { return &call_opts_; } private: CallContainer<GrpcCall>* const container_; @@ -72,7 +73,9 @@ class GrpcCall { string* status_message_; }; -} // namespace +} // namespace internal + +using internal::GrpcCall; GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, int64 timeout_in_ms) @@ -110,28 +113,6 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements, Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t, AsyncOpKernel::DoneCallback done) { - auto address = address_t.flat<string>(); - auto method = method_t.flat<string>(); - auto request = request_t.flat<string>(); - - // Stubs are maintained by the GrpcRPCFactory class and will be - // deleted when the class is destroyed. - ::grpc::GenericStub* singleton_stub = nullptr; - if (address.size() == 1) { - singleton_stub = GetOrCreateStubForAddress(address(0)); - } - auto get_stub = [&address, this, - singleton_stub](int64 ix) -> ::grpc::GenericStub* { - return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix)) - : singleton_stub; - }; - auto get_method_ptr = [&method](int64 ix) -> const string* { - return (method.size() > 1) ? &(method(ix)) : &(method(0)); - }; - auto get_request_ptr = [&request](int64 ix) -> const string* { - return (request.size() > 1) ? &(request(ix)) : &(request(0)); - }; - if (try_rpc) { // In this case status_code will never be set in the response, // so we just set it to OK. @@ -140,49 +121,22 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements, static_cast<int>(errors::Code::OK)); } - CancellationManager* cm = ctx->cancellation_manager(); - CancellationToken cancellation_token = cm->get_cancellation_token(); - - // This object will delete itself when done. - auto* container = - new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc, - std::move(done), cancellation_token); - - auto response = response_t->flat<string>(); - int32* status_code_ptr = nullptr; - string* status_message_ptr = nullptr; - if (try_rpc) { - status_code_ptr = status_code_t->flat<int32>().data(); - status_message_ptr = status_message_t->flat<string>().data(); - } - for (int i = 0; i < num_elements; ++i) { - container->calls()->emplace_back( - container, i, try_rpc, get_request_ptr(i), &response(i), - (try_rpc) ? &status_code_ptr[i] : nullptr, - (try_rpc) ? &status_message_ptr[i] : nullptr); - } + CallContainer<GrpcCall>::CreateCallFn create_call_fn = + [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t]( + CallContainer<GrpcCall>* container, int index) { + CreateCall(request_t, try_rpc, index, container, response_t, + status_code_t, status_message_t); + }; - int i = 0; - for (GrpcCall& call : *(container->calls())) { - // This object will delete itself when done. - new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i), - call.request(), call.response(), - /*done=*/[&call](const Status& s) { call.Done(s); }, - call.call_opts(), fail_fast_, timeout_in_ms_); - ++i; - } + CallContainer<GrpcCall>::StartCallFn start_call_fn = + [this, &address_t, &method_t](GrpcCall* call) { + StartCall(address_t, method_t, call); + }; - // Need to register this callback after all the RPCs are in - // flight; otherwise we may try to cancel an RPC *before* it - // launches, which is a no-op, and then fall into a deadlock. - bool is_cancelled = !cm->RegisterCallback( - cancellation_token, [container]() { container->StartCancel(); }); - - if (is_cancelled) { - ctx->SetStatus(errors::Cancelled("Operation has been cancelled.")); - // container's reference counter will take care of calling done(). - container->StartCancel(); - } + // This object will delete itself when done. + new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc, + std::move(done), std::move(create_call_fn), + std::move(start_call_fn)); } ::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress( @@ -210,4 +164,53 @@ GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress( /*target=*/address, ::grpc::InsecureChannelCredentials(), args); } +void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc, + int index, CallContainer<GrpcCall>* container, + Tensor* response_t, Tensor* status_code_t, + Tensor* status_message_t) { + auto request = request_t.flat<string>(); + auto get_request_ptr = [&request](int64 ix) -> const string* { + return (request.size() > 1) ? &(request(ix)) : &(request(0)); + }; + auto response = response_t->flat<string>(); + int32* status_code_ptr = nullptr; + string* status_message_ptr = nullptr; + if (try_rpc) { + status_code_ptr = status_code_t->flat<int32>().data(); + status_message_ptr = status_message_t->flat<string>().data(); + } + container->RegisterCall(container, index, try_rpc, get_request_ptr(index), + &response(index), + (try_rpc) ? &status_code_ptr[index] : nullptr, + (try_rpc) ? &status_message_ptr[index] : nullptr); +} + +void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t, + GrpcCall* call) { + auto address = address_t.flat<string>(); + auto method = method_t.flat<string>(); + // Stubs are maintained by the GrpcRPCFactory class and will be + // deleted when the class is destroyed. + ::grpc::GenericStub* singleton_stub = nullptr; + if (address.size() == 1) { + singleton_stub = GetOrCreateStubForAddress(address(0)); + } + auto get_stub = [&address, this, + singleton_stub](int64 ix) -> ::grpc::GenericStub* { + return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix)) + : singleton_stub; + }; + auto get_method_ptr = [&method](int64 ix) -> const string* { + return (method.size() > 1) ? &(method(ix)) : &(method(0)); + }; + + int index = call->index(); + // This object will delete itself when done. + new RPCState<string>(get_stub(index), &completion_queue_, + *get_method_ptr(index), call->request(), + call->response(), + /*done=*/[call](const Status& s) { call->Done(s); }, + call->call_opts(), fail_fast_, timeout_in_ms_); +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h index 34ec235aaf..29394c84b5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h @@ -20,10 +20,16 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/rpc/call_container.h" #include "tensorflow/core/util/rpc/rpc_factory.h" namespace tensorflow { +// Forward declaration of GrpcCall. +namespace internal { +class GrpcCall; +} // namespace internal + class GrpcRPCFactory : public RPCFactory { public: explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, @@ -42,6 +48,18 @@ class GrpcRPCFactory : public RPCFactory { virtual ChannelPtr CreateChannelForAddress(const string& address); private: + // Creates a call and registers it with given `container`. The `index` is used + // to index into the tensor arguments. + void CreateCall(const Tensor& request_t, const bool try_rpc, int index, + CallContainer<internal::GrpcCall>* container, + Tensor* response_t, Tensor* status_code_t, + Tensor* status_message_t); + + // Asynchronously invokes the given `call`. The call completion is handled + // by the call container the call was previously registered with. + void StartCall(const Tensor& address_t, const Tensor& method_t, + internal::GrpcCall* call); + ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address); bool fail_fast_; diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h index 7f36056797..e1226a7f16 100644 --- a/tensorflow/core/util/rpc/call_container.h +++ b/tensorflow/core/util/rpc/call_container.h @@ -26,53 +26,60 @@ limitations under the License. namespace tensorflow { -template <typename Call> +namespace internal { +// The following class is used for coordination between a `CallContainer` +// instance and a cancellation callback to make sure that the `CallContainer` +// instance waits for the cancellation callback to be destroyed (either because +// a cancellation occurred or because the callback was deregistered) before +// deleting itself. Without this coordination the cancellation callback could +// attempt to access a `CallContainer` instance that is no longer valid. +class NotifyWhenDestroyed { + public: + explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification) + : notification_(std::move(notification)) {} + + ~NotifyWhenDestroyed() { notification_->Notify(); } + + private: + std::shared_ptr<Notification> notification_; +}; +} // namespace internal + +// The following class is responsible for the life cycle management of a set of +// RPC calls. The calls are started when an instance of the class is created and +// the class contract guarantees to invoke a "done" callback provided by the +// caller when all RPC calls have either completed or been cancelled. +// +// The caller should not make any assumptions about the validity of an instance +// of this class after the provided callback has been invoked, which may be +// immediately after the instance was created. +template <class Call> class CallContainer { public: + typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn; + typedef std::function<void(Call*)> StartCallFn; + + // Uses the provided `create_call_fn` and `start_call_fn` functions to create + // and start a set of RPC calls. When all RPC calls have either completed or + // been cancelled, the `done` callback is invoked. The caller should not make + // any assumptions about the validity of the created instance as the instance + // will delete itself after invoking the `done` callback. explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc, AsyncOpKernel::DoneCallback done, - CancellationToken token) - : ctx_(ctx), - done_(std::move(done)), - token_(token), - fail_fast_(fail_fast), - try_rpc_(try_rpc) { - CHECK_GT(num_calls, 0); - - // This will run when all RPCs are finished. - reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { - ctx_->cancellation_manager()->DeregisterCallback(token_); - ctx_->SetStatus(s); - done_(); - delete this; - }); - - // Subtract reference count from the initial creation. - core::ScopedUnref unref(reffed_status_callback_); - - for (int i = 0; i < num_calls; ++i) { - // Increase the reference on the callback for each new RPC. - reffed_status_callback_->Ref(); - } - } + CreateCallFn create_call_fn, + StartCallFn start_call_fn); - std::list<Call>* calls() { return &calls_; } + // Registers a call with this container. This method expects its arguments to + // match those of a `Call` constructor as it forwards them to an underlying + // collection, which creates a `Call` instance in place. + template <class... Args> + void RegisterCall(Args&&... args); - void StartCancel() { - // Once this loop is done, can no longer assume anything is valid - // because "delete this" may have been immediately called. - // Nothing should run after this loop. - for (auto& call : calls_) { - call.StartCancel(); - } - } + // Starts the cancellation of all RPC calls managed by this container. + void StartCancel(); - void Done(const Status& s, int index) { - if (!try_rpc_) { - reffed_status_callback_->UpdateStatus(s); - } - reffed_status_callback_->Unref(); - } + // Indicates that the `index`-th RPC call has finished. + void Done(const Status& s, int index); private: OpKernelContext* ctx_; @@ -81,10 +88,88 @@ class CallContainer { const CancellationToken token_; const bool fail_fast_; const bool try_rpc_; + std::shared_ptr<Notification> callback_destroyed_; // Performs its own reference counting. ReffedStatusCallback* reffed_status_callback_; }; +template <class Call> +CallContainer<Call>::CallContainer( + OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc, + AsyncOpKernel::DoneCallback done, + typename CallContainer<Call>::CreateCallFn create_call_fn, + typename CallContainer<Call>::StartCallFn start_call_fn) + : ctx_(ctx), + done_(std::move(done)), + token_(ctx->cancellation_manager()->get_cancellation_token()), + fail_fast_(fail_fast), + try_rpc_(try_rpc), + callback_destroyed_(new Notification) { + CHECK_GT(num_calls, 0); + + // This will run when all RPCs are finished. + reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { + ctx_->cancellation_manager()->DeregisterCallback(token_); + ctx_->SetStatus(s); + done_(); + callback_destroyed_->WaitForNotification(); + delete this; + }); + + // The cancellation callback needs to be registered before the RPC calls are + // started to make sure that the callback is properly cleaned up by the + // `reffed_status_callback` when all calls complete. At the same time, the + // cancellation callback should wait for the RPC calls to be started for the + // cancellation to take effect. + std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed( + new internal::NotifyWhenDestroyed(callback_destroyed_)); + std::shared_ptr<Notification> calls_started(new Notification); + bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback( + token_, [this, calls_started, notify_when_destroyed]() { + calls_started->WaitForNotification(); + StartCancel(); + }); + + for (int i = 0; i < num_calls; ++i) { + create_call_fn(this, i); + // Increase the reference on the callback for each new RPC. + reffed_status_callback_->Ref(); + } + for (Call& call : calls_) { + start_call_fn(&call); + } + calls_started->Notify(); + + if (is_cancelled) { + ctx_->SetStatus(errors::Cancelled("Operation has been cancelled.")); + StartCancel(); + } + + // Subtract reference count from the initial creation. + reffed_status_callback_->Unref(); +} + +template <class Call> +template <class... Args> +void CallContainer<Call>::RegisterCall(Args&&... args) { + calls_.emplace_back(std::forward<Args>(args)...); +} + +template <class Call> +void CallContainer<Call>::StartCancel() { + for (auto& call : calls_) { + call.StartCancel(); + } +} + +template <class Call> +void CallContainer<Call>::Done(const Status& s, int index) { + if (!try_rpc_) { + reffed_status_callback_->UpdateStatus(s); + } + reffed_status_callback_->Unref(); +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h index 9bf078c0f4..c4eaaf4457 100644 --- a/tensorflow/core/util/rpc/rpc_factory.h +++ b/tensorflow/core/util/rpc/rpc_factory.h @@ -32,10 +32,11 @@ class RPCFactory { RPCFactory() {} virtual ~RPCFactory() {} - // Start a Call() to methods `method_t` at addresses `address_t` with + // Asynchronously invokes methods `method_t` at addresses `address_t` with // request strings from `request_t`. Any of these may be scalar // Tensors, in which case the operands are broadcasted. - // Upon completion of all requests, `response_t` will be populated. + // Upon completion of all requests, `response_t` will be populated and the + // `done` callback will be invoked. // // If `try_rpc` is `true`, then `status_message_t` and // `status_code_t` will be populated as well. |