Diffstat (limited to 'tensorflow/core/common_runtime/eager/execute.cc')
-rw-r--r--  tensorflow/core/common_runtime/eager/execute.cc | 137 ++++++++++++++++++++++++++++++++++-----------
1 file changed, 101 insertions(+), 36 deletions(-)
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 39bda9119c..0c0fbc729c 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -88,6 +88,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
   TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
   const Device* actual_device =
       handle_device == nullptr ? ctx->HostCPU() : handle_device;
+  const Device* op_device =
+      op->Device() == nullptr ? ctx->HostCPU() : op->Device();
 
   if (expected_device != actual_device) {
     switch (ctx->GetDevicePlacementPolicy()) {
@@ -106,8 +108,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
             " cannot compute ", op->Name(), " as input #", i,
             " was expected to be on ", expected_device->name(),
             " but is actually on ",
-            actual_device->name(), " (operation running on ",
-            op->Device()->name(), ")",
+            actual_device->name(), " (operation running on ", op_device->name(),
+            ")",
             " Tensors can be copied explicitly using .gpu() or .cpu() "
             "methods,"
             " or transparently copied by using tf.enable_eager_execution("
@@ -118,7 +120,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
       LOG(WARNING) << "before computing " << op->Name() << " input #" << i
                    << " was expected to be on " << expected_device->name()
                    << " but is actually on " << actual_device->name()
-                   << " (operation running on " << op->Device()->name()
+                   << " (operation running on " << op_device->name()
                    << "). This triggers a copy which can be a performance "
                       "bottleneck.";
       break;
@@ -128,7 +130,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
   // We are only here if the policy is warn or silent copies, so we should
   // trigger a copy.
   auto pre_time = Env::Default()->NowMicros();
-  TensorHandle* result_handle;
+  TensorHandle* result_handle = nullptr;
   Status status = EagerCopyToDevice(
       *handle, ctx, expected_device->name().c_str(), &result_handle);
   if (run_metadata != nullptr) {
@@ -173,7 +175,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
     tensorflow::TensorHandle* handle = op->Inputs()[i];
     if (handle->dtype != kernel->input_type(i)) {
       return errors::InvalidArgument(
-          "cannot compute ", op->Name(), " as input #", i,
+          "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(kernel->input_type(i)),
          " tensor but is a ", DataTypeString(handle->dtype), " tensor");
     }
@@ -512,7 +514,8 @@ Status EagerLocalExecute(EagerOperation* op,
     // See WARNING comment in Execute (before kernel->Run) - would be nice to
     // rework to avoid this subtlety.
     tf_shared_lock l(*ctx->FunctionsMu());
-    status = KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+    status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
+                                   kernel);
     if (!status.ok()) {
       delete kernel;
       return status;
@@ -582,6 +585,87 @@ Status EagerLocalExecute(EagerOperation* op,
   return status;
 }
 
+std::function<void()> GetRemoteTensorDestructor(
+    EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
+    uint64 op_id, int output_num) {
+  return [ctx, eager_client, context_id, op_id, output_num]() {
+    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
+    request->set_context_id(context_id);
+
+    auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
+    handle_to_decref->set_op_id(op_id);
+    handle_to_decref->set_output_num(output_num);
+
+    if (ctx->Async()) {
+      tensorflow::uint64 id = ctx->NextId();
+      auto* node =
+          new eager::RemoteExecuteNode(id, std::move(request), eager_client);
+      ctx->ExecutorAdd(node);
+    } else {
+      eager::EnqueueRequest* actual_request = request.release();
+      eager::EnqueueResponse* response = new eager::EnqueueResponse;
+      eager_client->EnqueueAsync(
+          actual_request, response,
+          [actual_request, response](const tensorflow::Status& s) {
+            delete actual_request;
+            delete response;
+          });
+    }
+
+    return tensorflow::Status::OK();
+  };
+}
+
+// When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
+// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
+// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
+//
+// However, in some configurations the node that has the tensor to be copied
+// isn't running a server (WorkerService RPC interface). For such cases,
+// this function enables sending tensors using the EagerService.SendTensor RPC
+// *on the receiver*.
+Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
+                             Device* recv_device, TensorHandle** result) {
+  eager::EagerClient* eager_client;
+  uint64 context_id;
+  TF_RETURN_IF_ERROR(
+      ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
+
+  eager::SendTensorRequest request;
+  eager::SendTensorResponse response;
+
+  request.set_context_id(context_id);
+  request.set_op_id(ctx->NextId());
+  request.set_device_name(recv_device->name());
+
+  const Tensor* tensor;
+  TF_RETURN_IF_ERROR(h->Tensor(&tensor));
+  tensor->AsProtoTensorContent(request.add_tensors());
+
+  const tensorflow::uint64 id = request.op_id();
+
+  // TODO(nareshmodi): support making this call async.
+  Notification n;
+  Status status;
+  eager_client->SendTensorAsync(&request, &response,
+                                [&n, &status](const Status& s) {
+                                  status = s;
+                                  n.Notify();
+                                });
+  n.WaitForNotification();
+  if (!status.ok()) return status;
+
+  std::function<void()> destructor =
+      GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0);
+
+  *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0,
+                             tensor->dtype(), std::move(destructor),
+                             recv_device, recv_device, ctx);
+  (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));
+
+  return Status::OK();
+}
+
 Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
                           int* num_retvals) {
 #ifdef __ANDROID__
@@ -595,10 +679,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   TF_RETURN_IF_ERROR(
       ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
 
-  eager::EnqueueRequest request;
+  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
   eager::EnqueueResponse response;
 
-  auto* remote_op = request.add_queue()->mutable_operation();
+  request->set_context_id(context_id);
+
+  auto* remote_op = request->add_queue()->mutable_operation();
 
   for (int i = 0; i < op->Inputs().size(); i++) {
     tensorflow::Device* input_device;
@@ -628,8 +714,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
   remote_op->set_device(op->Device()->name());
 
-  request.set_context_id(context_id);
-
   DataTypeVector output_dtypes;
   TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
 
@@ -651,32 +735,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   for (int i = 0; i < *num_retvals; i++) {
     // TODO(nareshmodi): Change the callback to instead add the decref to a list
     // of pending decrefs that we can send as a batch with the next execute.
-    std::function<void()> callback = [ctx, eager_client, context_id, id, i]() {
-      eager::EnqueueRequest request;
-      request.set_context_id(context_id);
-
-      auto* handle_to_decref = request.add_queue()->mutable_handle_to_decref();
-      handle_to_decref->set_op_id(id);
-      handle_to_decref->set_output_num(i);
-
-      if (ctx->Async()) {
-        tensorflow::uint64 id = ctx->NextId();
-        auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
-        ctx->ExecutorAdd(node);
-      } else {
-        Notification n;
-        eager::EnqueueResponse response;
-        eager_client->EnqueueAsync(
-            &request, &response,
-            [&n](const tensorflow::Status& s) { n.Notify(); });
-        n.WaitForNotification();
-      }
-
-      return tensorflow::Status::OK();
-    };
+    std::function<void()> destructor =
+        GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);
 
     retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
-                                  output_dtypes[i], std::move(callback),
+                                  output_dtypes[i], std::move(destructor),
                                   op_device, op_device, op->EagerContext());
   }
 
@@ -690,7 +753,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
     }
     // Unable to capture via std::move, so bind instead.
     auto* node = new eager::RemoteExecuteNode(
-        remote_node_id, request, eager_client, op->Inputs(),
+        remote_node_id, std::move(request), eager_client, op->Inputs(),
         std::bind(
            [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
               const Status& status, const eager::EnqueueResponse& response) {
@@ -707,7 +770,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   } else {
     Notification n;
     Status status;
-    eager_client->EnqueueAsync(&request, &response,
+    eager_client->EnqueueAsync(request.get(), &response,
                                [&n, &status](const Status& s) {
                                  status = s;
                                  n.Notify();
@@ -936,6 +999,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   if (sender_is_local && recver_is_local) {
     return LocalEagerCopyToDevice(h, ctx, recv_device, result);
+  } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
+    return EagerRemoteSendTensor(ctx, h, recv_device, result);
   } else {
     string wire_id = GetUniqueWireID();
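
To make the refactor concrete: the patch replaces a per-output inline lambda with a shared GetRemoteTensorDestructor() factory and ties the returned callback to the lifetime of the TensorHandle, so that when the local handle dies the remote worker is told to release its copy of the tensor. Below is a minimal, self-contained C++ sketch of that pattern. It is not TensorFlow code: RemoteHandle, MakeRemoteTensorDestructor, and the printed "decref" line are hypothetical stand-ins for TensorHandle, GetRemoteTensorDestructor, and the EnqueueRequest/handle_to_decref RPC.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <utility>

    // Stand-in for a handle to a tensor that lives on a remote worker. The
    // real TensorHandle is reference-counted; here the plain destructor of
    // this simplified class plays the role of the final Unref().
    class RemoteHandle {
     public:
      RemoteHandle(uint64_t op_id, int output_num,
                   std::function<void()> destructor)
          : op_id_(op_id),
            output_num_(output_num),
            destructor_(std::move(destructor)) {}

      // When the handle dies, run the destructor callback, which in the
      // patch enqueues a handle_to_decref request on the remote worker.
      ~RemoteHandle() {
        if (destructor_) destructor_();
      }

      uint64_t op_id() const { return op_id_; }
      int output_num() const { return output_num_; }

     private:
      uint64_t op_id_;
      int output_num_;
      std::function<void()> destructor_;
    };

    // Mirrors the shape of GetRemoteTensorDestructor(): build one callback
    // that captures everything needed to tell the remote side to drop the
    // tensor. The stream output stands in for the decref RPC.
    std::function<void()> MakeRemoteTensorDestructor(uint64_t context_id,
                                                     uint64_t op_id,
                                                     int output_num) {
      return [context_id, op_id, output_num]() {
        std::cout << "decref: context=" << context_id << " op=" << op_id
                  << " output=" << output_num << "\n";
      };
    }

    int main() {
      {
        RemoteHandle h(/*op_id=*/42, /*output_num=*/0,
                       MakeRemoteTensorDestructor(/*context_id=*/7, 42, 0));
        // ... use the handle ...
      }  // Handle goes out of scope here; the decref "RPC" fires.
      return 0;
    }

The same lifetime concern explains the switch from a stack-allocated eager::EnqueueRequest to a std::unique_ptr<eager::EnqueueRequest> in EagerRemoteExecute: the request must outlive the asynchronous EnqueueAsync() call, so the non-async branch of the new destructor releases the pointer and deletes it (together with the response) inside the completion callback.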