aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/partitioned_function_ops.cc
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-07-03 12:21:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 12:24:24 -0700
commit962c639b27b40afafdc41d7bffca2ee1d2ccd1cf (patch)
treed8e69f2b7707b5b32ffcea11844c61cc2c499007 /tensorflow/core/kernels/partitioned_function_ops.cc
parent69fe072cb467c0723e8c5266c8b32288fb3104a8 (diff)
Make functions defined with tfe.defun respect devices when executing.
Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops so that the created functions can execute across devices. This should strictly increase the set of functions that tfe.defun can faithfully execute. Previous to this change, functions executed through tfe.defun would ignore device annotations and only run on a single device. It is not yet possible to execute a function across multiple processes. Specifically, this CL: (1) Adds a stateful version of PartitionedCall, (2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default, (3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`, (4) Fixes bugs in PartitionedCallOp related to the placement of resource-touching ops / which args and retvals are always on host memory, and also removes the requirement for args/retvals to be passed through the host. PiperOrigin-RevId: 203164388
Diffstat (limited to 'tensorflow/core/kernels/partitioned_function_ops.cc')
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc307
1 files changed, 232 insertions, 75 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index b6ee808091..71c1b56fbd 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
@@ -42,8 +43,7 @@ namespace {
// TODO(akshayka): Support distributed execution.
class PartitionedCallOp : public AsyncOpKernel {
public:
- explicit PartitionedCallOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx), local_device_name_(ctx->device()->name()) {
+ explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
}
@@ -55,6 +55,9 @@ class PartitionedCallOp : public AsyncOpKernel {
errors::Internal("No function library is provided."),
done);
+ OpInputList args;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
+
// The function body's graph is placed and partitioned the first time
// `ComputeAsync` is invoked; every subsequent invocation calls each
// of the function shards yielded by partitioning.
@@ -67,16 +70,35 @@ class PartitionedCallOp : public AsyncOpKernel {
// via, e.g., virtual device annotations and a list of device names supplied
// through an attribute.
//
- // TODO(akshayka): Lift the constraint pinning inputs and outputs to the
- // local device.
- //
// TODO(akshayka): Add a fastpath for functions that execute on a single
// device.
{
mutex_lock l(mu_);
- if (!partitioned_) {
- auto graph = tensorflow::MakeUnique<Graph>(OpRegistry::Global());
- OP_REQUIRES_OK_ASYNC(ctx, GetGraphFromFunction(lib, graph.get()), done);
+ if (function_handles_.find(lib) == function_handles_.end()) {
+ if (local_device_name_.empty()) {
+ // The full local device name isn't known at kernel construction
+ // time, hence the need to set it here.
+ local_device_name_ = lib->device()->name();
+ }
+
+ // TODO(b/37549631): Because this kernel may correspond to a stateful
+ // op, it may be shared by multiple subgraphs, which in turn may have
+ // different `FunctionLibraryRuntime` objects and therefore different
+ // `FHandle` namespaces. As such, we partition on a per-FLR basis.
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ FHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+ &handle),
+ done);
+ const FunctionBody* fbody = lib->GetFunctionBody(handle);
+ OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
+ errors::Internal("Could not find handle ", handle),
+ done);
+ auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ CopyGraph(*fbody->graph, graph.get());
+ OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
DeviceSet device_set;
for (auto d : lib->device_mgr()->ListDevices()) {
@@ -94,9 +116,14 @@ class PartitionedCallOp : public AsyncOpKernel {
// an OpKernel, so functions are instantiated in an overlay library.
overlay_lib_.reset(new FunctionLibraryDefinition(
*lib->GetFunctionLibraryDefinition()));
+ auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
+ // TODO(akshayka): Fail gracefully if the set of devices corresponds
+ // to more than one address space.
const string& target = pair.first;
const auto& subgraph = pair.second;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, UpdateArgAndRetMetadata(target, subgraph.get()), done);
FunctionDef shard;
string unique_name = UniquifyFunctionName(func_.name());
OP_REQUIRES_OK_ASYNC(
@@ -111,40 +138,52 @@ class PartitionedCallOp : public AsyncOpKernel {
lib->Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
&handle),
done);
- function_handles_.emplace(target, handle);
+ handles->emplace(target, handle);
}
- partitioned_ = true;
+
+ function_handles_.emplace(lib, std::move(handles));
}
}
- ExecuteFunctions(lib, ctx, std::move(done));
+ ExecuteFunctions(lib, ctx, args, std::move(done));
}
private:
typedef std::pair<string, FHandle> DeviceAndFHandle;
+ typedef std::pair<std::vector<int>, std::vector<int>> ArgAndRetIndices;
+ typedef std::pair<std::vector<AllocatorAttributes>,
+ std::vector<AllocatorAttributes>>
+ ArgAndRetAllocAttrs;
- // `func_` encapsulates the original, unsharded function.
- // Copies the graph backing `func_` into `*graph`, pinning the input and
- // output nodes to the local device.
+ // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
+ // corresponding resource lives. This ensures that the Placer assigns ops that
+ // access these resources to the appropriate devices. This method throws an
+ // error if any two resource inputs live on different devices.
//
- // `*graph` must be a freshly allocated graph.
- Status GetGraphFromFunction(FunctionLibraryRuntime* lib, Graph* graph) {
- FunctionLibraryRuntime::InstantiateOptions opts;
- FHandle handle;
- TF_RETURN_IF_ERROR(lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
- opts, &handle));
- const FunctionBody* fbody = lib->GetFunctionBody(handle);
- if (fbody == nullptr) {
- return errors::Internal("Could not find handle ", handle);
- }
- CopyGraph(*fbody->graph, graph);
-
- // Pin the inputs and outputs to the local device to simplify the
- // function-dispatching logic.
+ // TODO(akshayka): Remove the single-device constraint once we have a
+ // mechanism for telling the Placer that an op supports heterogeneous
+ // devices among its input resources.
+ Status PinResourceArgs(Graph* graph, const OpInputList& args) {
+ string device;
for (Node* node : graph->op_nodes()) {
string node_type = node->type_string();
- if (node_type == FunctionLibraryDefinition::kArgOp ||
- node_type == FunctionLibraryDefinition::kRetOp) {
- node->set_assigned_device_name(local_device_name_);
+ if (node_type == FunctionLibraryDefinition::kArgOp) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
+ DataType dtype = attr_value->type();
+ if (dtype == DT_RESOURCE) {
+ ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const string& handle_device = handle.device();
+ if (device.empty()) {
+ device = handle_device;
+ } else if (device != handle_device) {
+ return errors::Internal(
+ "Resources must reside on a single device; observed devices ",
+ device, " and ", handle_device);
+ }
+ node->set_assigned_device_name(handle_device);
+ }
}
}
return Status::OK();
@@ -198,9 +237,104 @@ class PartitionedCallOp : public AsyncOpKernel {
return Status::OK();
}
- // Executes the partitioned functions.
+ // Each subgraph produced by partitioning the function body contains a subset
+ // of the original `Arg` and `Retval` nodes. This function performs
+ // bookkeeping to track which `Arg` and `Retval` nodes were placed on a
+ // particular device / subgraph.
+ //
+ // More specifically, this function
+ // (1) rewrites the indices of the `Arg` and `Retval` nodes placed on a
+ // particular device,
+ // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
+ // device, and
+ // (3) records which `Arg` and `Retval` nodes live in host memory.
+ Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
+ if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
+ // This function has already been partitioned, albeit for a different
+ // function library.
+ return Status::OK();
+ }
+
+ ArgAndRetIndices indices;
+ std::vector<int>* arg_indices = &indices.first;
+ std::vector<int>* ret_indices = &indices.second;
+ std::vector<std::pair<Node*, int>> arg_nodes;
+ std::vector<std::pair<Node*, int>> ret_nodes;
+ const AttrValue* attr_value;
+
+ for (Node* node : subgraph->op_nodes()) {
+ string node_type = node->type_string();
+ if (node_type == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ arg_indices->push_back(index);
+ arg_nodes.push_back(std::make_pair(node, index));
+ } else if (node_type == FunctionLibraryDefinition::kRetOp) {
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ ret_indices->push_back(index);
+ ret_nodes.push_back(std::make_pair(node, index));
+ }
+ }
+
+ auto sort_by_index = [](std::pair<Node*, int> one,
+ std::pair<Node*, int> two) -> bool {
+ return one.second < two.second;
+ };
+ std::sort(arg_nodes.begin(), arg_nodes.end(), sort_by_index);
+ std::sort(ret_nodes.begin(), ret_nodes.end(), sort_by_index);
+ for (int i = 0; i < arg_nodes.size(); ++i) {
+ Node* arg = arg_nodes[i].first;
+ arg->AddAttr("index", i);
+ TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
+ AllocatorAttributes alloc_attr;
+ DataType type = attr_value->type();
+ if (MTypeFromDType(type) == HOST_MEMORY) {
+ alloc_attr.set_on_host(true);
+ }
+ arg_and_ret_alloc_attrs_[device].first.push_back(alloc_attr);
+ }
+ for (int i = 0; i < ret_nodes.size(); ++i) {
+ Node* ret = ret_nodes[i].first;
+ ret->AddAttr("index", i);
+ TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
+ AllocatorAttributes alloc_attr;
+ DataType type = attr_value->type();
+ if (MTypeFromDType(type) == HOST_MEMORY) {
+ alloc_attr.set_on_host(true);
+ }
+ arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
+ }
+
+ arg_and_ret_indices_.emplace(device, indices);
+ return Status::OK();
+ }
+
+ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
+ const OpInputList& arguments) {
+ std::vector<Tensor> args;
+ args.reserve(indices.size());
+ for (int i : indices) {
+ args.push_back(arguments[i]);
+ }
+ return args;
+ }
+
void ExecuteFunctions(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
- DoneCallback done) LOCKS_EXCLUDED(mu_) {
+ const OpInputList& op_args, DoneCallback done)
+ LOCKS_EXCLUDED(mu_) {
+ const gtl::FlatMap<string, FHandle>* handles;
+ {
+ mutex_lock l(mu_);
+ handles = function_handles_[lib].get();
+ }
+ if (handles->empty()) {
+ // Trivial case where the function body is empty.
+ ctx->SetStatus(Status::OK());
+ done();
+ return;
+ }
+
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.step_container = ctx->step_container();
@@ -215,11 +349,6 @@ class PartitionedCallOp : public AsyncOpKernel {
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
opts.rendezvous = rendez;
- OpInputList arguments;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
- // Dummy args vector for the remote shards, which do not have inputs.
- std::vector<Tensor> dummy_args;
-
StatusCallback callback = std::bind(
[](Rendezvous* rendez, DoneCallback& done, const Status& status) {
rendez->Unref();
@@ -227,48 +356,62 @@ class PartitionedCallOp : public AsyncOpKernel {
},
rendez, std::move(done), std::placeholders::_1);
auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
- for (int i = 1; i < function_handles_.size(); ++i) {
+ for (int i = 1; i < handles->size(); ++i) {
refcounted_done->Ref();
}
- for (const auto& pair : function_handles_) {
- const string& target_device = pair.first;
+ for (const auto& pair : *handles) {
+ const string& target = pair.first;
FHandle handle = pair.second;
- VLOG(3) << "Running function shard on device " << target_device;
- if (target_device == local_device_name_) {
+ VLOG(3) << "Running function shard on device " << target;
+ ArgAndRetIndices indices = arg_and_ret_indices_[target];
+ ArgAndRetAllocAttrs alloc_attrs = arg_and_ret_alloc_attrs_[target];
+ const std::vector<int>& arg_indices = indices.first;
+ const std::vector<int>& ret_indices = indices.second;
+ opts.args_alloc_attrs = alloc_attrs.first;
+ opts.rets_alloc_attrs = alloc_attrs.second;
+ if (target == local_device_name_) {
opts.remote_execution = false;
- std::vector<Tensor> args;
- args.reserve(arguments.size());
- for (const Tensor& argument : arguments) {
- args.push_back(argument);
- }
- auto* rets = new std::vector<Tensor>;
- lib->Run(opts, handle, args, rets,
- [rets, refcounted_done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else {
- for (int i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- refcounted_done->Unref();
- });
+ std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(
+ opts, handle, args, rets,
+ [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ VLOG(3) << "Local execution failed: " << status;
+ ctx->SetStatus(status);
+ } else {
+ for (int i = 0; i < rets->size(); ++i) {
+ ctx->set_output(ret_indices[i], (*rets)[i]);
+ }
+ }
+ delete rets;
+ VLOG(3) << "Finished local execution.";
+ refcounted_done->Unref();
+ });
} else {
opts.remote_execution = true;
- std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
- lib->Run(opts, handle, dummy_args, dummy_rets,
- [dummy_rets, refcounted_done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- }
- delete dummy_rets;
- refcounted_done->Unref();
- });
+ std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(
+ opts, handle, args, rets,
+ [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ VLOG(3) << "Remote execution failed: " << status;
+ ctx->SetStatus(status);
+ } else {
+ for (int i = 0; i < rets->size(); ++i) {
+ ctx->set_output(ret_indices[i], (*rets)[i]);
+ }
+ }
+ delete rets;
+ VLOG(3) << "Finished remote execution.";
+ refcounted_done->Unref();
+ });
}
}
}
+
string UniquifyFunctionName(const string& name) {
for (;; ++suffix_) {
const string candidate = strings::StrCat(name, "_", suffix_);
@@ -279,26 +422,40 @@ class PartitionedCallOp : public AsyncOpKernel {
}
NameAttrList func_;
- const string local_device_name_;
+ string local_device_name_;
// Function shards are added to `overlay_lib_`.
std::unique_ptr<FunctionLibraryDefinition> overlay_lib_;
- // A map from device names to handles of function shards; this map is
- // read-only after the first execution of the OpKernel.
- gtl::FlatMap<string, FHandle> function_handles_;
+ // Contains maps from device names to handles of function shards, keyed by
+ // FunctionLibraryRuntime pointers. (Because this kernel may be instantiated
+ // for a stateful op, different invocations of it may use different FLRs.)
+ gtl::FlatMap<FunctionLibraryRuntime*,
+ std::unique_ptr<gtl::FlatMap<string, FHandle>>>
+ function_handles_ GUARDED_BY(mu_);
+ // Map from device name to the indices of the arguments and return values
+ // placed on that device. Read-only after the first invocation.
+ gtl::FlatMap<string, ArgAndRetIndices> arg_and_ret_indices_;
+ // Map from device name to alloc attrs for arguments and return values of the
+ // function placed on that device. Read-only after the first invocation.
+ gtl::FlatMap<string, ArgAndRetAllocAttrs> arg_and_ret_alloc_attrs_;
mutex mu_;
- bool partitioned_ GUARDED_BY(mu_) = false;
// Used to uniquify function names in `overlay_lib_`.
uint32 suffix_ = 0;
};
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_CPU),
+ PartitionedCallOp);
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_GPU),
PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_GPU),
+ PartitionedCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_SYCL),
PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL),
+ PartitionedCallOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace