diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-21 09:53:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 09:57:04 -0700 |
commit | 4e252b2f997904769711b242bb37027706b08b7f (patch) | |
tree | ec947d8c655e4ca2603bc55ee5118b210c29c968 /tensorflow/core/common_runtime | |
parent | 233de7fe7efcf7c8fbcd4d3653a1f6d32feff5c8 (diff) |
Set device on resource touching ops before checking where to execute.
Thanks @alextp for finding the bug!
PiperOrigin-RevId: 213999971
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/eager/execute.cc | 41 |
1 file changed, 21 insertions, 20 deletions
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 1da1326a9a..1bc63616d0 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -251,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op, EagerContext* ctx = op->EagerContext(); auto status = ctx->GetStatus(); if (!status.ok()) return status; - // Ensure all resource-touching ops run in the device the resource is, - // regardless of anything else that has been specified. This is identical to - // the graph mode behavior. - for (int i = 0; i < op->Inputs().size(); ++i) { - Device* input_op_device = nullptr; - status = op->Inputs()[i]->OpDevice(&input_op_device); - if (!status.ok()) return status; - VLOG(2) << "for op " << op->Name() << " input " << i << " " - << DataTypeString(op->Inputs()[i]->dtype) << " " - << (input_op_device == nullptr ? "cpu" : input_op_device->name()) - << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name()); - if (op->Inputs()[i]->dtype == DT_RESOURCE && - (input_op_device != op->Device() || input_op_device == nullptr)) { - Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device; - VLOG(1) << "Changing device of operation " << op->Name() << " to " - << d->name() << " because input #" << i - << " is a resource in this device."; - op->SetDevice(d); - } - } Device* device = op->Device(); Fprint128 cache_key = op->MutableAttrs()->CacheKey( @@ -604,6 +584,27 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, Status EagerExecute(EagerOperation* op, gtl::InlinedVector<TensorHandle*, 2>* retvals, int* num_retvals) { + // Ensure all resource-touching ops run in the device the resource is, + // regardless of anything else that has been specified. This is identical to + // the graph mode behavior. 
+ EagerContext* ctx = op->EagerContext(); + for (int i = 0; i < op->Inputs().size(); ++i) { + Device* input_op_device = nullptr; + auto status = op->Inputs()[i]->OpDevice(&input_op_device); + if (!status.ok()) return status; + VLOG(2) << "for op " << op->Name() << " input " << i << " " + << DataTypeString(op->Inputs()[i]->dtype) << " " + << (input_op_device == nullptr ? "cpu" : input_op_device->name()) + << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name()); + if (op->Inputs()[i]->dtype == DT_RESOURCE && + (input_op_device != op->Device() || input_op_device == nullptr)) { + Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device; + VLOG(1) << "Changing device of operation " << op->Name() << " to " + << d->name() << " because input #" << i + << " is a resource in this device."; + op->SetDevice(d); + } + } bool op_is_local = IsLocal(op->EagerContext(), op->Device()); if (op_is_local) { |