aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-21 09:53:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 09:57:04 -0700
commit4e252b2f997904769711b242bb37027706b08b7f (patch)
treeec947d8c655e4ca2603bc55ee5118b210c29c968 /tensorflow/core/common_runtime
parent233de7fe7efcf7c8fbcd4d3653a1f6d32feff5c8 (diff)
Set device on resource touching ops before checking where to execute.
Thanks @alextp for finding the bug! PiperOrigin-RevId: 213999971
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc41
1 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1da1326a9a..1bc63616d0 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -251,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
- for (int i = 0; i < op->Inputs().size(); ++i) {
- Device* input_op_device = nullptr;
- status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
- VLOG(2) << "for op " << op->Name() << " input " << i << " "
- << DataTypeString(op->Inputs()[i]->dtype) << " "
- << (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
- if (op->Inputs()[i]->dtype == DT_RESOURCE &&
- (input_op_device != op->Device() || input_op_device == nullptr)) {
- Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->Name() << " to "
- << d->name() << " because input #" << i
- << " is a resource in this device.";
- op->SetDevice(d);
- }
- }
Device* device = op->Device();
Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -604,6 +584,27 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
Status EagerExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
+ // Ensure all resource-touching ops run in the device the resource is,
+ // regardless of anything else that has been specified. This is identical to
+ // the graph mode behavior.
+ EagerContext* ctx = op->EagerContext();
+ for (int i = 0; i < op->Inputs().size(); ++i) {
+ Device* input_op_device = nullptr;
+ auto status = op->Inputs()[i]->OpDevice(&input_op_device);
+ if (!status.ok()) return status;
+ VLOG(2) << "for op " << op->Name() << " input " << i << " "
+ << DataTypeString(op->Inputs()[i]->dtype) << " "
+ << (input_op_device == nullptr ? "cpu" : input_op_device->name())
+ << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+ if (op->Inputs()[i]->dtype == DT_RESOURCE &&
+ (input_op_device != op->Device() || input_op_device == nullptr)) {
+ Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+ VLOG(1) << "Changing device of operation " << op->Name() << " to "
+ << d->name() << " because input #" << i
+ << " is a resource in this device.";
+ op->SetDevice(d);
+ }
+ }
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {