diff options
author | 2018-07-26 15:56:19 -0700 | |
---|---|---|
committer | 2018-07-26 16:01:20 -0700 | |
commit | 4e0c5685af914684a77177002e4265a721c3e0ff (patch) | |
tree | be224eb28a04d7f3ff4652884c3458a2e9fb29fc | |
parent | 62df725269a89a0a5d877eae18d0c83155f2ea9d (diff) |
Don't make remote copy call when both send/recv devices are the same.
PiperOrigin-RevId: 206236233
-rw-r--r-- | tensorflow/core/common_runtime/eager/execute.cc | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 0c0fbc729c..f97fa4fadc 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -448,6 +448,14 @@ bool IsLocal(EagerContext* ctx, tensorflow::Device* d) { return ctx->local_device_mgr()->LookupDevice(d->name(), &tmp).ok(); } +bool OnSameTask(EagerContext* ctx, Device* first, Device* second) { + if (first == nullptr) first = ctx->HostCPU(); + if (second == nullptr) second = ctx->HostCPU(); + return first->parsed_name().job == second->parsed_name().job && + first->parsed_name().replica == second->parsed_name().replica && + first->parsed_name().task == second->parsed_name().task; +} + Status EagerLocalExecute(EagerOperation* op, gtl::InlinedVector<TensorHandle*, 2>* retvals, int* num_retvals) { @@ -689,7 +697,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, for (int i = 0; i < op->Inputs().size(); i++) { tensorflow::Device* input_device; TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device)); - if (op->Device() != input_device) { + if (op->Device() != input_device && + // If the expected and actual devices are on the same task, don't + // explicitly copy, and instead depend on the copy to happen locally + // when the op is executed on the device. + !OnSameTask(ctx, op->Device(), input_device)) { // TODO(b/110044833): It's possible the same tensor gets copied to the // remote device repeatedly. TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice( |