aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-07-03 12:21:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 12:24:24 -0700
commit962c639b27b40afafdc41d7bffca2ee1d2ccd1cf (patch)
treed8e69f2b7707b5b32ffcea11844c61cc2c499007 /tensorflow/core/distributed_runtime
parent69fe072cb467c0723e8c5266c8b32288fb3104a8 (diff)
Make functions defined with tfe.defun respect devices when executing.
Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops so that the created functions can execute across devices. This should strictly increase the set of functions that tfe.defun can faithfully execute. Previous to this change, functions executed through tfe.defun would ignore device annotations and only run on a single device. It is not yet possible to execute a function across multiple processes. Specifically, this CL: (1) Adds a stateful version of PartitionedCall, (2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default, (3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`, (4) Fixes bugs in PartitionedCallOp related to the placement of resource-touching ops / which args and retvals are always on host memory, and also removes the requirement for args/retvals to be passed through the host. PiperOrigin-RevId: 203164388
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 5a26d5bf48..8ecccd4d06 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -63,10 +63,10 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
}
*num_retvals += iter->second.i();
} else if (!output_arg.type_list_attr().empty()) {
- auto iter = attrs.find(output_arg.number_attr());
+ auto iter = attrs.find(output_arg.type_list_attr());
if (iter == attrs.end()) {
- return errors::InvalidArgument("Unable to find number_attr ",
- output_arg.number_attr(),
+ return errors::InvalidArgument("Unable to find type_list_attr ",
+ output_arg.type_list_attr(),
" for Op: ", op_name);
}
*num_retvals += iter->second.list().type_size();