diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-07-03 12:21:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-03 12:24:24 -0700 |
commit | 962c639b27b40afafdc41d7bffca2ee1d2ccd1cf (patch) | |
tree | d8e69f2b7707b5b32ffcea11844c61cc2c499007 /tensorflow/core/distributed_runtime | |
parent | 69fe072cb467c0723e8c5266c8b32288fb3104a8 (diff) |
Make functions defined with tfe.defun respect devices when executing.
Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops
so that the created functions can execute across devices. This should strictly
increase the set of functions that tfe.defun can faithfully execute.
Previous to this change, functions executed through tfe.defun would ignore
device annotations and only run on a single device. It is not yet possible to execute
a function across multiple processes.
Specifically, this CL:
(1) Adds a stateful version of PartitionedCall,
(2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default,
(3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`,
(4) Fixes bugs in PartitionedCallOp related to the placement of
resource-touching ops / which args and retvals are always on host memory, and
also removes the requirement for args/retvals to be passed through the host.
PiperOrigin-RevId: 203164388
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- | tensorflow/core/distributed_runtime/eager/eager_service_impl.cc | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 5a26d5bf48..8ecccd4d06 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -63,10 +63,10 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, } *num_retvals += iter->second.i(); } else if (!output_arg.type_list_attr().empty()) { - auto iter = attrs.find(output_arg.number_attr()); + auto iter = attrs.find(output_arg.type_list_attr()); if (iter == attrs.end()) { - return errors::InvalidArgument("Unable to find number_attr ", - output_arg.number_attr(), + return errors::InvalidArgument("Unable to find type_list_attr ", + output_arg.type_list_attr(), " for Op: ", op_name); } *num_retvals += iter->second.list().type_size(); |