aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-07-26 15:56:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 16:01:20 -0700
commit4e0c5685af914684a77177002e4265a721c3e0ff (patch)
treebe224eb28a04d7f3ff4652884c3458a2e9fb29fc
parent62df725269a89a0a5d877eae18d0c83155f2ea9d (diff)
Don't make remote copy call when both send/recv devices are the same.
PiperOrigin-RevId: 206236233
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 0c0fbc729c..f97fa4fadc 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -448,6 +448,14 @@ bool IsLocal(EagerContext* ctx, tensorflow::Device* d) {
return ctx->local_device_mgr()->LookupDevice(d->name(), &tmp).ok();
}
+bool OnSameTask(EagerContext* ctx, Device* first, Device* second) {
+ if (first == nullptr) first = ctx->HostCPU();
+ if (second == nullptr) second = ctx->HostCPU();
+ return first->parsed_name().job == second->parsed_name().job &&
+ first->parsed_name().replica == second->parsed_name().replica &&
+ first->parsed_name().task == second->parsed_name().task;
+}
+
Status EagerLocalExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
@@ -689,7 +697,11 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
- if (op->Device() != input_device) {
+ if (op->Device() != input_device &&
+ // If the expected and actual devices are on the same task, don't
+ // explicitly copy, and instead depend on the copy to happen locally
+ // when the op is executed on the device.
+ !OnSameTask(ctx, op->Device(), input_device)) {
// TODO(b/110044833): It's possible the same tensor gets copied to the
// remote device repeatedly.
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(