Diffstat (limited to 'tensorflow/core/common_runtime/eager/execute.cc')
 tensorflow/core/common_runtime/eager/execute.cc | 137
 1 file changed, 101 insertions(+), 36 deletions(-)
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 39bda9119c..0c0fbc729c 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -88,6 +88,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
const Device* actual_device =
handle_device == nullptr ? ctx->HostCPU() : handle_device;
+ const Device* op_device =
+ op->Device() == nullptr ? ctx->HostCPU() : op->Device();
if (expected_device != actual_device) {
switch (ctx->GetDevicePlacementPolicy()) {
@@ -106,8 +108,8 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
" cannot compute ",
op->Name(), " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
- actual_device->name(), " (operation running on ",
- op->Device()->name(), ")",
+ actual_device->name(), " (operation running on ", op_device->name(),
+ ")",
" Tensors can be copied explicitly using .gpu() or .cpu() "
"methods,"
" or transparently copied by using tf.enable_eager_execution("
@@ -118,7 +120,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
- << " (operation running on " << op->Device()->name()
+ << " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
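
Note: the three hunks above fix a potential null-pointer dereference. An eager
op with no explicit placement has op->Device() == nullptr, so formatting the
old diagnostics could crash before the message was ever printed. A minimal
sketch of the guard pattern used here:

    // Default a null op device to the host CPU before building any message.
    const Device* op_device =
        op->Device() == nullptr ? ctx->HostCPU() : op->Device();
    // op_device->name() is now always safe to call.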
@@ -128,7 +130,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
auto pre_time = Env::Default()->NowMicros();
- TensorHandle* result_handle;
+ TensorHandle* result_handle = nullptr;
Status status = EagerCopyToDevice(
*handle, ctx, expected_device->name().c_str(), &result_handle);
if (run_metadata != nullptr) {
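
Note: zero-initializing result_handle is defensive: EagerCopyToDevice assigns
the output pointer only on success (assumed from the surrounding control
flow), so without the initializer any later read of result_handle on the
error path would hit an indeterminate value. Sketch of the hazard, not
verbatim from this file:

    TensorHandle* result_handle = nullptr;  // stays null if the copy fails
    Status status = EagerCopyToDevice(*handle, ctx, name, &result_handle);
    // Reading result_handle when !status.ok() is now well-defined (null)
    // instead of undefined behavior.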
@@ -173,7 +175,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
tensorflow::TensorHandle* handle = op->Inputs()[i];
if (handle->dtype != kernel->input_type(i)) {
return errors::InvalidArgument(
- "cannot compute ", op->Name(), " as input #", i,
+ "cannot compute ", op->Name(), " as input #", i, "(zero-based)",
" was expected to be a ", DataTypeString(kernel->input_type(i)),
" tensor but is a ", DataTypeString(handle->dtype), " tensor");
}
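
Note: with the reworded message, a type mismatch now states explicitly that
input indices are zero-based. For example (op and dtypes assumed), feeding an
int32 tensor as the second input of a float MatMul would report:

    cannot compute MatMul as input #1 (zero-based) was expected to be a
    float tensor but is a int32 tensor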
@@ -512,7 +514,8 @@ Status EagerLocalExecute(EagerOperation* op,
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
tf_shared_lock l(*ctx->FunctionsMu());
- status = KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+ status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
+ kernel);
if (!status.ok()) {
delete kernel;
return status;
@@ -582,6 +585,87 @@ Status EagerLocalExecute(EagerOperation* op,
return status;
}
+std::function<void()> GetRemoteTensorDestructor(
+ EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
+ uint64 op_id, int output_num) {
+ return [ctx, eager_client, context_id, op_id, output_num]() {
+ std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
+ request->set_context_id(context_id);
+
+ auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
+ handle_to_decref->set_op_id(op_id);
+ handle_to_decref->set_output_num(output_num);
+
+ if (ctx->Async()) {
+ tensorflow::uint64 id = ctx->NextId();
+ auto* node =
+ new eager::RemoteExecuteNode(id, std::move(request), eager_client);
+ ctx->ExecutorAdd(node);
+ } else {
+ eager::EnqueueRequest* actual_request = request.release();
+ eager::EnqueueResponse* response = new eager::EnqueueResponse;
+ eager_client->EnqueueAsync(
+ actual_request, response,
+ [actual_request, response](const tensorflow::Status& s) {
+ delete actual_request;
+ delete response;
+ });
+ }
+
+ return tensorflow::Status::OK();
+ };
+}
+
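
Note: in the synchronous branch of GetRemoteTensorDestructor the destructor
must not block, so the request and response are moved to the heap and freed
only inside the completion callback. Distilled, the fire-and-forget shape is:

    // The callback owns the heap allocations, so the enclosing scope may
    // return before the RPC completes (error handling elided).
    auto* req = request.release();
    auto* resp = new eager::EnqueueResponse;
    eager_client->EnqueueAsync(req, resp, [req, resp](const Status& s) {
      delete req;   // freed only once the RPC layer is done with them
      delete resp;
    });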
+// When !ctx->UseSendTensorRPC(), tensors are shipped between remote
+// devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
+// sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
+//
+// However, in some configurations the node that has the tensor to be copied
+// isn't running a server (WorkerService RPC interface). For such cases,
+// this function enables sending tensors using the EagerService.SendTensor RPC
+// *on the receiver*.
+Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
+ Device* recv_device, TensorHandle** result) {
+ eager::EagerClient* eager_client;
+ uint64 context_id;
+ TF_RETURN_IF_ERROR(
+ ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
+
+ eager::SendTensorRequest request;
+ eager::SendTensorResponse response;
+
+ request.set_context_id(context_id);
+ request.set_op_id(ctx->NextId());
+ request.set_device_name(recv_device->name());
+
+ const Tensor* tensor;
+ TF_RETURN_IF_ERROR(h->Tensor(&tensor));
+ tensor->AsProtoTensorContent(request.add_tensors());
+
+ const tensorflow::uint64 id = request.op_id();
+
+ // TODO(nareshmodi): support making this call async.
+ Notification n;
+ Status status;
+ eager_client->SendTensorAsync(&request, &response,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ if (!status.ok()) return status;
+
+ std::function<void()> destructor =
+ GetRemoteTensorDestructor(ctx, eager_client, context_id, id, 0);
+
+ *result = new TensorHandle(id, /*output_num=*/0, /*remote_shape_node_id=*/0,
+ tensor->dtype(), std::move(destructor),
+ recv_device, recv_device, ctx);
+ (*result)->SetRemoteShape(MakeUnique<TensorShape>(tensor->shape()));
+
+ return Status::OK();
+}
+
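
Note: per the TODO, the SendTensor call is still blocking; it uses the
Notification idiom (also used further down in EagerRemoteExecute) to wait on
an async RPC. Distilled, with CallAsync standing in for SendTensorAsync or
EnqueueAsync:

    // Turn an async RPC into a blocking call: the callback records the
    // status and signals; the caller waits. Safe here because request and
    // response outlive the wait.
    Notification n;
    Status status;
    client->CallAsync(&request, &response, [&n, &status](const Status& s) {
      status = s;
      n.Notify();
    });
    n.WaitForNotification();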
Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
#ifdef __ANDROID__
@@ -595,10 +679,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
TF_RETURN_IF_ERROR(
ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
- eager::EnqueueRequest request;
+ std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
eager::EnqueueResponse response;
- auto* remote_op = request.add_queue()->mutable_operation();
+ request->set_context_id(context_id);
+
+ auto* remote_op = request->add_queue()->mutable_operation();
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
@@ -628,8 +714,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(op->Device()->name());
- request.set_context_id(context_id);
-
DataTypeVector output_dtypes;
TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
@@ -651,32 +735,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
for (int i = 0; i < *num_retvals; i++) {
// TODO(nareshmodi): Change the callback to instead add the decref to a list
// of pending decrefs that we can send as a batch with the next execute.
- std::function<void()> callback = [ctx, eager_client, context_id, id, i]() {
- eager::EnqueueRequest request;
- request.set_context_id(context_id);
-
- auto* handle_to_decref = request.add_queue()->mutable_handle_to_decref();
- handle_to_decref->set_op_id(id);
- handle_to_decref->set_output_num(i);
-
- if (ctx->Async()) {
- tensorflow::uint64 id = ctx->NextId();
- auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
- ctx->ExecutorAdd(node);
- } else {
- Notification n;
- eager::EnqueueResponse response;
- eager_client->EnqueueAsync(
- &request, &response,
- [&n](const tensorflow::Status& s) { n.Notify(); });
- n.WaitForNotification();
- }
-
- return tensorflow::Status::OK();
- };
+ std::function<void()> destructor =
+ GetRemoteTensorDestructor(ctx, eager_client, context_id, id, i);
retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
- output_dtypes[i], std::move(callback),
+ output_dtypes[i], std::move(destructor),
op_device, op_device, op->EagerContext());
}
@@ -690,7 +753,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
// Unable to capture via std::move, so bind instead.
auto* node = new eager::RemoteExecuteNode(
- remote_node_id, request, eager_client, op->Inputs(),
+ remote_node_id, std::move(request), eager_client, op->Inputs(),
std::bind(
[](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
const Status& status, const eager::EnqueueResponse& response) {
@@ -707,7 +770,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
} else {
Notification n;
Status status;
- eager_client->EnqueueAsync(&request, &response,
+ eager_client->EnqueueAsync(request.get(), &response,
[&n, &status](const Status& s) {
status = s;
n.Notify();
@@ -936,6 +999,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
if (sender_is_local && recver_is_local) {
return LocalEagerCopyToDevice(h, ctx, recv_device, result);
+ } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
+ return EagerRemoteSendTensor(ctx, h, recv_device, result);
} else {
string wire_id = GetUniqueWireID();
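
Note: this last hunk adds a third branch to EagerCopyToDevice. The resulting
routing, restated compactly (helper names are from this file; the condensed
shape is illustrative):

    if (sender_is_local && recver_is_local) {
      // Same process: plain device-to-device copy.
      return LocalEagerCopyToDevice(h, ctx, recv_device, result);
    } else if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
      // Push the tensor to the receiver via EagerService.SendTensor.
      return EagerRemoteSendTensor(ctx, h, recv_device, result);
    } else {
      // Rendezvous path: the receiver pulls the tensor from the sender via
      // WorkerService.RecvTensor, keyed by a unique wire id.
      string wire_id = GetUniqueWireID();
      // ... _Send/_Recv based copy ...
    }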