about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/rpc/python/kernel_tests/BUILD | 1
-rw-r--r-- tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py | 1
-rw-r--r-- tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py | 60
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc | 135
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h | 18
-rw-r--r-- tensorflow/core/util/rpc/call_container.h | 165
-rw-r--r-- tensorflow/core/util/rpc/rpc_factory.h | 5
7 files changed, 251 insertions, 134 deletions
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index f3e6731213..2311c15a68 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -28,7 +28,6 @@ py_library(
py_library(
name = "rpc_op_test_base",
srcs = ["rpc_op_test_base.py"],
- tags = ["notsan"],
deps = [
":test_example_proto_py",
"//tensorflow/contrib/proto",
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
index e2e0dbc7a2..3fc6bfbb4d 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
@@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
_protocol = 'grpc'
invalid_method_string = 'Method not found'
+ connect_failed_string = 'Connect Failed'
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
super(RpcOpTest, self).__init__(methodName)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 89f3ee1a1c..27273d16b1 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -93,40 +93,39 @@ class RpcOpTestBase(object):
response_values = sess.run(response_tensors)
self.assertAllEqual(response_values.shape, [0])
- def testInvalidAddresses(self):
- with self.test_session() as sess:
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method='/InvalidService.IncrementTestShapes',
- address=self._address,
- request=''))
+ def testInvalidMethod(self):
+ for method in [
+ '/InvalidService.IncrementTestShapes',
+ self.get_method_name('InvalidMethodName')
+ ]:
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(self.rpc(method=method, address=self._address, request=''))
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
- request=''))
+ _, status_code_value, status_message_value = sess.run(
+ self.try_rpc(method=method, address=self._address, request=''))
+ self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertTrue(
+ self.invalid_method_string in status_message_value.decode('ascii'))
- # This also covers the case of address=''
- # and address='localhost:293874293874'
+ def testInvalidAddress(self):
+ # This covers the case of address='' and address='localhost:293874293874'
+ address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ with self.test_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
- address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+ address=address,
request=''))
-
- # Test invalid method with the TryRpc op
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
+ method=self.get_method_name('IncrementTestShapes'),
+ address=address,
request=''))
- self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertEqual(errors.UNAVAILABLE, status_code_value)
self.assertTrue(
- self.invalid_method_string in status_message_value.decode('ascii'))
+ self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
with self.test_session() as sess:
@@ -138,6 +137,18 @@ class RpcOpTestBase(object):
with self.assertRaisesOpError(I_WARNED_YOU):
sess.run(response_tensors)
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+ address=self._address,
+ request='')
+ self.assertEqual(response_tensors.shape, ())
+ self.assertEqual(status_code.shape, ())
+ self.assertEqual(status_message.shape, ())
+ status_code_value, status_message_value = sess.run((status_code,
+ status_message))
+ self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
+ self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
+
def testSometimesFailingMethodWithManyRequests(self):
with self.test_session() as sess:
# Fail hard by default.
@@ -197,8 +208,7 @@ class RpcOpTestBase(object):
address=self._address,
request=request_tensors) for _ in range(10)
]
- # Launch parallel 10 calls to the RpcOp, each containing
- # 20 rpc requests.
+ # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
many_response_values = sess.run(many_response_tensors)
self.assertEqual(10, len(many_response_values))
for response_values in many_response_values:
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index d004abd1c1..cde6b785dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -30,7 +30,7 @@ limitations under the License.
namespace tensorflow {
-namespace {
+namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
@@ -57,9 +57,10 @@ class GrpcCall {
container_->Done(s, index_);
}
+ CallOptions* call_opts() { return &call_opts_; }
+ int index() { return index_; }
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
- CallOptions* call_opts() { return &call_opts_; }
private:
CallContainer<GrpcCall>* const container_;
@@ -72,7 +73,9 @@ class GrpcCall {
string* status_message_;
};
-} // namespace
+} // namespace internal
+
+using internal::GrpcCall;
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)
@@ -110,28 +113,6 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) {
- auto address = address_t.flat<string>();
- auto method = method_t.flat<string>();
- auto request = request_t.flat<string>();
-
- // Stubs are maintained by the GrpcRPCFactory class and will be
- // deleted when the class is destroyed.
- ::grpc::GenericStub* singleton_stub = nullptr;
- if (address.size() == 1) {
- singleton_stub = GetOrCreateStubForAddress(address(0));
- }
- auto get_stub = [&address, this,
- singleton_stub](int64 ix) -> ::grpc::GenericStub* {
- return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
- : singleton_stub;
- };
- auto get_method_ptr = [&method](int64 ix) -> const string* {
- return (method.size() > 1) ? &(method(ix)) : &(method(0));
- };
- auto get_request_ptr = [&request](int64 ix) -> const string* {
- return (request.size() > 1) ? &(request(ix)) : &(request(0));
- };
-
if (try_rpc) {
// In this case status_code will never be set in the response,
// so we just set it to OK.
@@ -140,49 +121,22 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
static_cast<int>(errors::Code::OK));
}
- CancellationManager* cm = ctx->cancellation_manager();
- CancellationToken cancellation_token = cm->get_cancellation_token();
-
- // This object will delete itself when done.
- auto* container =
- new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
- std::move(done), cancellation_token);
-
- auto response = response_t->flat<string>();
- int32* status_code_ptr = nullptr;
- string* status_message_ptr = nullptr;
- if (try_rpc) {
- status_code_ptr = status_code_t->flat<int32>().data();
- status_message_ptr = status_message_t->flat<string>().data();
- }
- for (int i = 0; i < num_elements; ++i) {
- container->calls()->emplace_back(
- container, i, try_rpc, get_request_ptr(i), &response(i),
- (try_rpc) ? &status_code_ptr[i] : nullptr,
- (try_rpc) ? &status_message_ptr[i] : nullptr);
- }
+ CallContainer<GrpcCall>::CreateCallFn create_call_fn =
+ [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
+ CallContainer<GrpcCall>* container, int index) {
+ CreateCall(request_t, try_rpc, index, container, response_t,
+ status_code_t, status_message_t);
+ };
- int i = 0;
- for (GrpcCall& call : *(container->calls())) {
- // This object will delete itself when done.
- new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
- call.request(), call.response(),
- /*done=*/[&call](const Status& s) { call.Done(s); },
- call.call_opts(), fail_fast_, timeout_in_ms_);
- ++i;
- }
+ CallContainer<GrpcCall>::StartCallFn start_call_fn =
+ [this, &address_t, &method_t](GrpcCall* call) {
+ StartCall(address_t, method_t, call);
+ };
- // Need to register this callback after all the RPCs are in
- // flight; otherwise we may try to cancel an RPC *before* it
- // launches, which is a no-op, and then fall into a deadlock.
- bool is_cancelled = !cm->RegisterCallback(
- cancellation_token, [container]() { container->StartCancel(); });
-
- if (is_cancelled) {
- ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
- // container's reference counter will take care of calling done().
- container->StartCancel();
- }
+ // This object will delete itself when done.
+ new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
+ std::move(done), std::move(create_call_fn),
+ std::move(start_call_fn));
}
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
@@ -210,4 +164,53 @@ GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}
+void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
+ int index, CallContainer<GrpcCall>* container,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t) {
+ auto request = request_t.flat<string>();
+ auto get_request_ptr = [&request](int64 ix) -> const string* {
+ return (request.size() > 1) ? &(request(ix)) : &(request(0));
+ };
+ auto response = response_t->flat<string>();
+ int32* status_code_ptr = nullptr;
+ string* status_message_ptr = nullptr;
+ if (try_rpc) {
+ status_code_ptr = status_code_t->flat<int32>().data();
+ status_message_ptr = status_message_t->flat<string>().data();
+ }
+ container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
+ &response(index),
+ (try_rpc) ? &status_code_ptr[index] : nullptr,
+ (try_rpc) ? &status_message_ptr[index] : nullptr);
+}
+
+void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
+ GrpcCall* call) {
+ auto address = address_t.flat<string>();
+ auto method = method_t.flat<string>();
+ // Stubs are maintained by the GrpcRPCFactory class and will be
+ // deleted when the class is destroyed.
+ ::grpc::GenericStub* singleton_stub = nullptr;
+ if (address.size() == 1) {
+ singleton_stub = GetOrCreateStubForAddress(address(0));
+ }
+ auto get_stub = [&address, this,
+ singleton_stub](int64 ix) -> ::grpc::GenericStub* {
+ return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
+ : singleton_stub;
+ };
+ auto get_method_ptr = [&method](int64 ix) -> const string* {
+ return (method.size() > 1) ? &(method(ix)) : &(method(0));
+ };
+
+ int index = call->index();
+ // This object will delete itself when done.
+ new RPCState<string>(get_stub(index), &completion_queue_,
+ *get_method_ptr(index), call->request(),
+ call->response(),
+ /*done=*/[call](const Status& s) { call->Done(s); },
+ call->call_opts(), fail_fast_, timeout_in_ms_);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
index 34ec235aaf..29394c84b5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
@@ -20,10 +20,16 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
+// Forward declaration of GrpcCall.
+namespace internal {
+class GrpcCall;
+} // namespace internal
+
class GrpcRPCFactory : public RPCFactory {
public:
explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
@@ -42,6 +48,18 @@ class GrpcRPCFactory : public RPCFactory {
virtual ChannelPtr CreateChannelForAddress(const string& address);
private:
+ // Creates a call and registers it with given `container`. The `index` is used
+ // to index into the tensor arguments.
+ void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
+ CallContainer<internal::GrpcCall>* container,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t);
+
+ // Asynchronously invokes the given `call`. The call completion is handled
+ // by the call container the call was previously registered with.
+ void StartCall(const Tensor& address_t, const Tensor& method_t,
+ internal::GrpcCall* call);
+
::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
bool fail_fast_;
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
index 7f36056797..e1226a7f16 100644
--- a/tensorflow/core/util/rpc/call_container.h
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -26,53 +26,60 @@ limitations under the License.
namespace tensorflow {
-template <typename Call>
+namespace internal {
+// The following class is used for coordination between a `CallContainer`
+// instance and a cancellation callback to make sure that the `CallContainer`
+// instance waits for the cancellation callback to be destroyed (either because
+// a cancellation occurred or because the callback was deregistered) before
+// deleting itself. Without this coordination the cancellation callback could
+// attempt to access a `CallContainer` instance that is no longer valid.
+class NotifyWhenDestroyed {
+ public:
+ explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
+ : notification_(std::move(notification)) {}
+
+ ~NotifyWhenDestroyed() { notification_->Notify(); }
+
+ private:
+ std::shared_ptr<Notification> notification_;
+};
+} // namespace internal
+
+// The following class is responsible for the life cycle management of a set of
+// RPC calls. The calls are started when an instance of the class is created and
+// the class contract guarantees to invoke a "done" callback provided by the
+// caller when all RPC calls have either completed or been cancelled.
+//
+// The caller should not make any assumptions about the validity of an instance
+// of this class after the provided callback has been invoked, which may be
+// immediately after the instance was created.
+template <class Call>
class CallContainer {
public:
+ typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
+ typedef std::function<void(Call*)> StartCallFn;
+
+ // Uses the provided `create_call_fn` and `start_call_fn` functions to create
+ // and start a set of RPC calls. When all RPC calls have either completed or
+ // been cancelled, the `done` callback is invoked. The caller should not make
+ // any assumptions about the validity of the created instance as the instance
+ // will delete itself after invoking the `done` callback.
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
bool try_rpc, AsyncOpKernel::DoneCallback done,
- CancellationToken token)
- : ctx_(ctx),
- done_(std::move(done)),
- token_(token),
- fail_fast_(fail_fast),
- try_rpc_(try_rpc) {
- CHECK_GT(num_calls, 0);
-
- // This will run when all RPCs are finished.
- reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
- ctx_->cancellation_manager()->DeregisterCallback(token_);
- ctx_->SetStatus(s);
- done_();
- delete this;
- });
-
- // Subtract reference count from the initial creation.
- core::ScopedUnref unref(reffed_status_callback_);
-
- for (int i = 0; i < num_calls; ++i) {
- // Increase the reference on the callback for each new RPC.
- reffed_status_callback_->Ref();
- }
- }
+ CreateCallFn create_call_fn,
+ StartCallFn start_call_fn);
- std::list<Call>* calls() { return &calls_; }
+ // Registers a call with this container. This method expects its arguments to
+ // match those of a `Call` constructor as it forwards them to an underlying
+ // collection, which creates a `Call` instance in place.
+ template <class... Args>
+ void RegisterCall(Args&&... args);
- void StartCancel() {
- // Once this loop is done, can no longer assume anything is valid
- // because "delete this" may have been immediately called.
- // Nothing should run after this loop.
- for (auto& call : calls_) {
- call.StartCancel();
- }
- }
+ // Starts the cancellation of all RPC calls managed by this container.
+ void StartCancel();
- void Done(const Status& s, int index) {
- if (!try_rpc_) {
- reffed_status_callback_->UpdateStatus(s);
- }
- reffed_status_callback_->Unref();
- }
+ // Indicates that the `index`-th RPC call has finished.
+ void Done(const Status& s, int index);
private:
OpKernelContext* ctx_;
@@ -81,10 +88,88 @@ class CallContainer {
const CancellationToken token_;
const bool fail_fast_;
const bool try_rpc_;
+ std::shared_ptr<Notification> callback_destroyed_;
// Performs its own reference counting.
ReffedStatusCallback* reffed_status_callback_;
};
+template <class Call>
+CallContainer<Call>::CallContainer(
+ OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
+ AsyncOpKernel::DoneCallback done,
+ typename CallContainer<Call>::CreateCallFn create_call_fn,
+ typename CallContainer<Call>::StartCallFn start_call_fn)
+ : ctx_(ctx),
+ done_(std::move(done)),
+ token_(ctx->cancellation_manager()->get_cancellation_token()),
+ fail_fast_(fail_fast),
+ try_rpc_(try_rpc),
+ callback_destroyed_(new Notification) {
+ CHECK_GT(num_calls, 0);
+
+ // This will run when all RPCs are finished.
+ reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+ ctx_->cancellation_manager()->DeregisterCallback(token_);
+ ctx_->SetStatus(s);
+ done_();
+ callback_destroyed_->WaitForNotification();
+ delete this;
+ });
+
+ // The cancellation callback needs to be registered before the RPC calls are
+ // started to make sure that the callback is properly cleaned up by the
+ // `reffed_status_callback` when all calls complete. At the same time, the
+ // cancellation callback should wait for the RPC calls to be started for the
+ // cancellation to take effect.
+ std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
+ new internal::NotifyWhenDestroyed(callback_destroyed_));
+ std::shared_ptr<Notification> calls_started(new Notification);
+ bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
+ token_, [this, calls_started, notify_when_destroyed]() {
+ calls_started->WaitForNotification();
+ StartCancel();
+ });
+
+ for (int i = 0; i < num_calls; ++i) {
+ create_call_fn(this, i);
+ // Increase the reference on the callback for each new RPC.
+ reffed_status_callback_->Ref();
+ }
+ for (Call& call : calls_) {
+ start_call_fn(&call);
+ }
+ calls_started->Notify();
+
+ if (is_cancelled) {
+ ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
+ StartCancel();
+ }
+
+ // Subtract reference count from the initial creation.
+ reffed_status_callback_->Unref();
+}
+
+template <class Call>
+template <class... Args>
+void CallContainer<Call>::RegisterCall(Args&&... args) {
+ calls_.emplace_back(std::forward<Args>(args)...);
+}
+
+template <class Call>
+void CallContainer<Call>::StartCancel() {
+ for (auto& call : calls_) {
+ call.StartCancel();
+ }
+}
+
+template <class Call>
+void CallContainer<Call>::Done(const Status& s, int index) {
+ if (!try_rpc_) {
+ reffed_status_callback_->UpdateStatus(s);
+ }
+ reffed_status_callback_->Unref();
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
index 9bf078c0f4..c4eaaf4457 100644
--- a/tensorflow/core/util/rpc/rpc_factory.h
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -32,10 +32,11 @@ class RPCFactory {
RPCFactory() {}
virtual ~RPCFactory() {}
- // Start a Call() to methods `method_t` at addresses `address_t` with
+ // Asynchronously invokes methods `method_t` at addresses `address_t` with
// request strings from `request_t`. Any of these may be scalar
// Tensors, in which case the operands are broadcasted.
- // Upon completion of all requests, `response_t` will be populated.
+ // Upon completion of all requests, `response_t` will be populated and the
+ // `done` callback will be invoked.
//
// If `try_rpc` is `true`, then `status_message_t` and
// `status_code_t` will be populated as well.