aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc135
1 files changed, 69 insertions, 66 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index d004abd1c1..cde6b785dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -30,7 +30,7 @@ limitations under the License.
namespace tensorflow {
-namespace {
+namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
@@ -57,9 +57,10 @@ class GrpcCall {
container_->Done(s, index_);
}
+ CallOptions* call_opts() { return &call_opts_; }
+ int index() { return index_; }
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
- CallOptions* call_opts() { return &call_opts_; }
private:
CallContainer<GrpcCall>* const container_;
@@ -72,7 +73,9 @@ class GrpcCall {
string* status_message_;
};
-} // namespace
+} // namespace internal
+
+using internal::GrpcCall;
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)
@@ -110,28 +113,6 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) {
- auto address = address_t.flat<string>();
- auto method = method_t.flat<string>();
- auto request = request_t.flat<string>();
-
- // Stubs are maintained by the GrpcRPCFactory class and will be
- // deleted when the class is destroyed.
- ::grpc::GenericStub* singleton_stub = nullptr;
- if (address.size() == 1) {
- singleton_stub = GetOrCreateStubForAddress(address(0));
- }
- auto get_stub = [&address, this,
- singleton_stub](int64 ix) -> ::grpc::GenericStub* {
- return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
- : singleton_stub;
- };
- auto get_method_ptr = [&method](int64 ix) -> const string* {
- return (method.size() > 1) ? &(method(ix)) : &(method(0));
- };
- auto get_request_ptr = [&request](int64 ix) -> const string* {
- return (request.size() > 1) ? &(request(ix)) : &(request(0));
- };
-
if (try_rpc) {
// In this case status_code will never be set in the response,
// so we just set it to OK.
@@ -140,49 +121,22 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
static_cast<int>(errors::Code::OK));
}
- CancellationManager* cm = ctx->cancellation_manager();
- CancellationToken cancellation_token = cm->get_cancellation_token();
-
- // This object will delete itself when done.
- auto* container =
- new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
- std::move(done), cancellation_token);
-
- auto response = response_t->flat<string>();
- int32* status_code_ptr = nullptr;
- string* status_message_ptr = nullptr;
- if (try_rpc) {
- status_code_ptr = status_code_t->flat<int32>().data();
- status_message_ptr = status_message_t->flat<string>().data();
- }
- for (int i = 0; i < num_elements; ++i) {
- container->calls()->emplace_back(
- container, i, try_rpc, get_request_ptr(i), &response(i),
- (try_rpc) ? &status_code_ptr[i] : nullptr,
- (try_rpc) ? &status_message_ptr[i] : nullptr);
- }
+ CallContainer<GrpcCall>::CreateCallFn create_call_fn =
+ [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
+ CallContainer<GrpcCall>* container, int index) {
+ CreateCall(request_t, try_rpc, index, container, response_t,
+ status_code_t, status_message_t);
+ };
- int i = 0;
- for (GrpcCall& call : *(container->calls())) {
- // This object will delete itself when done.
- new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
- call.request(), call.response(),
- /*done=*/[&call](const Status& s) { call.Done(s); },
- call.call_opts(), fail_fast_, timeout_in_ms_);
- ++i;
- }
+ CallContainer<GrpcCall>::StartCallFn start_call_fn =
+ [this, &address_t, &method_t](GrpcCall* call) {
+ StartCall(address_t, method_t, call);
+ };
- // Need to register this callback after all the RPCs are in
- // flight; otherwise we may try to cancel an RPC *before* it
- // launches, which is a no-op, and then fall into a deadlock.
- bool is_cancelled = !cm->RegisterCallback(
- cancellation_token, [container]() { container->StartCancel(); });
-
- if (is_cancelled) {
- ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
- // container's reference counter will take care of calling done().
- container->StartCancel();
- }
+ // This object will delete itself when done.
+ new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
+ std::move(done), std::move(create_call_fn),
+ std::move(start_call_fn));
}
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
@@ -210,4 +164,53 @@ GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}
+void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
+ int index, CallContainer<GrpcCall>* container,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t) {
+ auto request = request_t.flat<string>();
+ auto get_request_ptr = [&request](int64 ix) -> const string* {
+ return (request.size() > 1) ? &(request(ix)) : &(request(0));
+ };
+ auto response = response_t->flat<string>();
+ int32* status_code_ptr = nullptr;
+ string* status_message_ptr = nullptr;
+ if (try_rpc) {
+ status_code_ptr = status_code_t->flat<int32>().data();
+ status_message_ptr = status_message_t->flat<string>().data();
+ }
+ container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
+ &response(index),
+ (try_rpc) ? &status_code_ptr[index] : nullptr,
+ (try_rpc) ? &status_message_ptr[index] : nullptr);
+}
+
+void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
+ GrpcCall* call) {
+ auto address = address_t.flat<string>();
+ auto method = method_t.flat<string>();
+ // Stubs are maintained by the GrpcRPCFactory class and will be
+ // deleted when the class is destroyed.
+ ::grpc::GenericStub* singleton_stub = nullptr;
+ if (address.size() == 1) {
+ singleton_stub = GetOrCreateStubForAddress(address(0));
+ }
+ auto get_stub = [&address, this,
+ singleton_stub](int64 ix) -> ::grpc::GenericStub* {
+ return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
+ : singleton_stub;
+ };
+ auto get_method_ptr = [&method](int64 ix) -> const string* {
+ return (method.size() > 1) ? &(method(ix)) : &(method(0));
+ };
+
+ int index = call->index();
+ // This object will delete itself when done.
+ new RPCState<string>(get_stub(index), &completion_queue_,
+ *get_method_ptr(index), call->request(),
+ call->response(),
+ /*done=*/[call](const Status& s) { call->Done(s); },
+ call->call_opts(), fail_fast_, timeout_in_ms_);
+}
+
} // namespace tensorflow