Diffstat (limited to 'tensorflow/core/kernels/partitioned_function_ops.cc')
-rw-r--r--   tensorflow/core/kernels/partitioned_function_ops.cc | 296
1 file changed, 220 insertions(+), 76 deletions(-)
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index b6ee808091..a7a9609c21 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_partition.h"
@@ -42,8 +43,7 @@ namespace {
 // TODO(akshayka): Support distributed execution.
 class PartitionedCallOp : public AsyncOpKernel {
  public:
-  explicit PartitionedCallOp(OpKernelConstruction* ctx)
-      : AsyncOpKernel(ctx), local_device_name_(ctx->device()->name()) {
+  explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
   }
 
@@ -55,6 +55,9 @@ class PartitionedCallOp : public AsyncOpKernel {
                       errors::Internal("No function library is provided."),
                       done);
 
+    OpInputList args;
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
+
     // The function body's graph is placed and partitioned the first time
     // `ComputeAsync` is invoked; every subsequent invocation calls each
     // of the function shards yielded by partitioning.
@@ -67,16 +70,35 @@ class PartitionedCallOp : public AsyncOpKernel {
     // via, e.g., virtual device annotations and a list of device names supplied
     // through an attribute.
     //
-    // TODO(akshayka): Lift the constraint pinning inputs and outputs to the
-    // local device.
-    //
     // TODO(akshayka): Add a fastpath for functions that execute on a single
     // device.
     {
       mutex_lock l(mu_);
-      if (!partitioned_) {
-        auto graph = tensorflow::MakeUnique<Graph>(OpRegistry::Global());
-        OP_REQUIRES_OK_ASYNC(ctx, GetGraphFromFunction(lib, graph.get()), done);
+      if (function_handles_.find(lib) == function_handles_.end()) {
+        if (local_device_name_.empty()) {
+          // The full local device name isn't known at kernel construction
+          // time, hence the need to set it here.
+          local_device_name_ = lib->device()->name();
+        }
+
+        // TODO(b/37549631): Because this kernel may correspond to a stateful
+        // op, it may be shared by multiple subgraphs, which in turn may have
+        // different `FunctionLibraryRuntime` objects and therefore different
+        // `FHandle` namespaces. As such, we partition on a per-FLR basis.
+        FunctionLibraryRuntime::InstantiateOptions opts;
+        FHandle handle;
+        OP_REQUIRES_OK_ASYNC(
+            ctx,
+            lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+                             &handle),
+            done);
+        const FunctionBody* fbody = lib->GetFunctionBody(handle);
+        OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
+                          errors::Internal("Could not find handle ", handle),
+                          done);
+        auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+        CopyGraph(*fbody->graph, graph.get());
+        OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
 
         DeviceSet device_set;
         for (auto d : lib->device_mgr()->ListDevices()) {
@@ -94,9 +116,14 @@ class PartitionedCallOp : public AsyncOpKernel {
         // an OpKernel, so functions are instantiated in an overlay library.
         overlay_lib_.reset(new FunctionLibraryDefinition(
             *lib->GetFunctionLibraryDefinition()));
+        auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
         for (const auto& pair : subgraphs) {
+          // TODO(akshayka): Fail gracefully if the set of devices corresponds
+          // to more than one address space.
           const string& target = pair.first;
           const auto& subgraph = pair.second;
+          OP_REQUIRES_OK_ASYNC(
+              ctx, UpdateArgAndRetMetadata(target, subgraph.get()), done);
           FunctionDef shard;
           string unique_name = UniquifyFunctionName(func_.name());
           OP_REQUIRES_OK_ASYNC(
@@ -111,40 +138,38 @@ class PartitionedCallOp : public AsyncOpKernel {
               lib->Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
                                &handle),
               done);
-          function_handles_.emplace(target, handle);
+          handles->emplace(target, handle);
         }
-        partitioned_ = true;
+
+        function_handles_.emplace(lib, std::move(handles));
       }
     }
-    ExecuteFunctions(lib, ctx, std::move(done));
+    ExecuteFunctions(lib, ctx, args, std::move(done));
   }
 
  private:
   typedef std::pair<string, FHandle> DeviceAndFHandle;
+  typedef std::pair<std::vector<int>, std::vector<int>> ArgAndRetIndices;
+  typedef std::pair<std::vector<AllocatorAttributes>,
+                    std::vector<AllocatorAttributes>>
+      ArgAndRetAllocAttrs;
 
-  // `func_` encapsulates the original, unsharded function.
-  // Copies the graph backing `func_` into `*graph`, pinning the input and
-  // output nodes to the local device.
-  //
-  // `*graph` must be a freshly allocated graph.
-  Status GetGraphFromFunction(FunctionLibraryRuntime* lib, Graph* graph) {
-    FunctionLibraryRuntime::InstantiateOptions opts;
-    FHandle handle;
-    TF_RETURN_IF_ERROR(lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
-                                        opts, &handle));
-    const FunctionBody* fbody = lib->GetFunctionBody(handle);
-    if (fbody == nullptr) {
-      return errors::Internal("Could not find handle ", handle);
-    }
-    CopyGraph(*fbody->graph, graph);
-
-    // Pin the inputs and outputs to the local device to simplify the
-    // function-dispatching logic.
+  // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which
+  // the corresponding resource lives. This ensures that the Placer assigns
+  // ops that access these resources to the appropriate devices.
+  Status PinResourceArgs(Graph* graph, const OpInputList& args) {
     for (Node* node : graph->op_nodes()) {
       string node_type = node->type_string();
-      if (node_type == FunctionLibraryDefinition::kArgOp ||
-          node_type == FunctionLibraryDefinition::kRetOp) {
-        node->set_assigned_device_name(local_device_name_);
+      if (node_type == FunctionLibraryDefinition::kArgOp) {
+        const AttrValue* attr_value;
+        TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+        int index = attr_value->i();
+        TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
+        DataType dtype = attr_value->type();
+        if (dtype == DT_RESOURCE) {
+          ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+          node->set_assigned_device_name(handle.device());
+        }
       }
     }
     return Status::OK();
@@ -198,9 +223,104 @@ class PartitionedCallOp : public AsyncOpKernel {
     return Status::OK();
   }
 
-  // Executes the partitioned functions.
+  // Each subgraph produced by partitioning the function body contains a
+  // subset of the original `Arg` and `Retval` nodes. This function performs
+  // bookkeeping to track which `Arg` and `Retval` nodes were placed on a
+  // particular device / subgraph.
+  //
+  // More specifically, this function
+  // (1) rewrites the indices of the `Arg` and `Retval` nodes placed on a
+  //     particular device,
+  // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
+  //     device, and
+  // (3) records which `Arg` and `Retval` nodes live in host memory.
+  Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
+    if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
+      // This function has already been partitioned, albeit for a different
+      // function library.
+      return Status::OK();
+    }
+
+    ArgAndRetIndices indices;
+    std::vector<int>* arg_indices = &indices.first;
+    std::vector<int>* ret_indices = &indices.second;
+    std::vector<std::pair<Node*, int>> arg_nodes;
+    std::vector<std::pair<Node*, int>> ret_nodes;
+    const AttrValue* attr_value;
+
+    for (Node* node : subgraph->op_nodes()) {
+      string node_type = node->type_string();
+      if (node_type == FunctionLibraryDefinition::kArgOp) {
+        TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+        int index = attr_value->i();
+        arg_indices->push_back(index);
+        arg_nodes.push_back(std::make_pair(node, index));
+      } else if (node_type == FunctionLibraryDefinition::kRetOp) {
+        TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+        int index = attr_value->i();
+        ret_indices->push_back(index);
+        ret_nodes.push_back(std::make_pair(node, index));
+      }
+    }
+
+    auto sort_by_index = [](std::pair<Node*, int> one,
+                            std::pair<Node*, int> two) -> bool {
+      return one.second < two.second;
+    };
+    std::sort(arg_nodes.begin(), arg_nodes.end(), sort_by_index);
+    std::sort(ret_nodes.begin(), ret_nodes.end(), sort_by_index);
+    for (int i = 0; i < arg_nodes.size(); ++i) {
+      Node* arg = arg_nodes[i].first;
+      arg->AddAttr("index", i);
+      TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      if (MTypeFromDType(type) == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_and_ret_alloc_attrs_[device].first.push_back(alloc_attr);
+    }
+    for (int i = 0; i < ret_nodes.size(); ++i) {
+      Node* ret = ret_nodes[i].first;
+      ret->AddAttr("index", i);
+      TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
+      AllocatorAttributes alloc_attr;
+      DataType type = attr_value->type();
+      if (MTypeFromDType(type) == HOST_MEMORY) {
+        alloc_attr.set_on_host(true);
+      }
+      arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
+    }
+
+    arg_and_ret_indices_.emplace(device, indices);
+    return Status::OK();
+  }
+
+  std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
+                                        const OpInputList& arguments) {
+    std::vector<Tensor> args;
+    args.reserve(indices.size());
+    for (int i : indices) {
+      args.push_back(arguments[i]);
+    }
+    return args;
+  }
+
   void ExecuteFunctions(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
-                        DoneCallback done) LOCKS_EXCLUDED(mu_) {
+                        const OpInputList& op_args, DoneCallback done)
+      LOCKS_EXCLUDED(mu_) {
+    const gtl::FlatMap<string, FHandle>* handles;
+    {
+      mutex_lock l(mu_);
+      handles = function_handles_[lib].get();
+    }
+    if (handles->empty()) {
+      // Trivial case where the function body is empty.
+      ctx->SetStatus(Status::OK());
+      done();
+      return;
+    }
+
     FunctionLibraryRuntime::Options opts;
     opts.step_id = ctx->step_id();
     opts.step_container = ctx->step_container();
@@ -210,16 +330,12 @@ class PartitionedCallOp : public AsyncOpKernel {
     // using device-specific threadpools when available.
     opts.runner = ctx->runner();
     opts.source_device = local_device_name_;
+    opts.allow_dead_tensors = true;
     // TODO(akshayka): Accommodate the multiple-worker scenario by adding the
     // constructed rendezvous to a rendezvous manager.
     Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
     opts.rendezvous = rendez;
 
-    OpInputList arguments;
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
-    // Dummy args vector for the remote shards, which do not have inputs.
-    std::vector<Tensor> dummy_args;
-
     StatusCallback callback = std::bind(
         [](Rendezvous* rendez, DoneCallback& done, const Status& status) {
           rendez->Unref();
@@ -227,48 +343,62 @@ class PartitionedCallOp : public AsyncOpKernel {
         },
         rendez, std::move(done), std::placeholders::_1);
     auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
-    for (int i = 1; i < function_handles_.size(); ++i) {
+    for (int i = 1; i < handles->size(); ++i) {
      refcounted_done->Ref();
     }
-    for (const auto& pair : function_handles_) {
-      const string& target_device = pair.first;
+    for (const auto& pair : *handles) {
+      const string& target = pair.first;
       FHandle handle = pair.second;
-      VLOG(3) << "Running function shard on device " << target_device;
-      if (target_device == local_device_name_) {
+      VLOG(3) << "Running function shard on device " << target;
+      ArgAndRetIndices indices = arg_and_ret_indices_[target];
+      ArgAndRetAllocAttrs alloc_attrs = arg_and_ret_alloc_attrs_[target];
+      const std::vector<int>& arg_indices = indices.first;
+      const std::vector<int>& ret_indices = indices.second;
+      opts.args_alloc_attrs = alloc_attrs.first;
+      opts.rets_alloc_attrs = alloc_attrs.second;
+      if (target == local_device_name_) {
         opts.remote_execution = false;
-        std::vector<Tensor> args;
-        args.reserve(arguments.size());
-        for (const Tensor& argument : arguments) {
-          args.push_back(argument);
-        }
-        auto* rets = new std::vector<Tensor>;
-        lib->Run(opts, handle, args, rets,
-                 [rets, refcounted_done, ctx](const Status& status) {
-                   if (!status.ok()) {
-                     ctx->SetStatus(status);
-                   } else {
-                     for (int i = 0; i < rets->size(); ++i) {
-                       ctx->set_output(i, (*rets)[i]);
-                     }
-                   }
-                   delete rets;
-                   refcounted_done->Unref();
-                 });
+        std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+        std::vector<Tensor>* rets = new std::vector<Tensor>;
+        lib->Run(
+            opts, handle, args, rets,
+            [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+              if (!status.ok()) {
+                VLOG(3) << "Local execution failed: " << status;
+                ctx->SetStatus(status);
+              } else {
+                for (int i = 0; i < rets->size(); ++i) {
+                  ctx->set_output(ret_indices[i], (*rets)[i]);
+                }
+              }
+              delete rets;
+              VLOG(3) << "Finished local execution.";
+              refcounted_done->Unref();
+            });
       } else {
         opts.remote_execution = true;
-        std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
-        lib->Run(opts, handle, dummy_args, dummy_rets,
-                 [dummy_rets, refcounted_done, ctx](const Status& status) {
-                   if (!status.ok()) {
-                     ctx->SetStatus(status);
-                   }
-                   delete dummy_rets;
-                   refcounted_done->Unref();
-                 });
+        std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+        std::vector<Tensor>* rets = new std::vector<Tensor>;
+        lib->Run(
+            opts, handle, args, rets,
+            [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+              if (!status.ok()) {
+                VLOG(3) << "Remote execution failed: " << status;
+                ctx->SetStatus(status);
+              } else {
+                for (int i = 0; i < rets->size(); ++i) {
+                  ctx->set_output(ret_indices[i], (*rets)[i]);
+                }
+              }
+              delete rets;
+              VLOG(3) << "Finished remote execution.";
+              refcounted_done->Unref();
+            });
       }
     }
   }
+
   string UniquifyFunctionName(const string& name) {
     for (;; ++suffix_) {
       const string candidate = strings::StrCat(name, "_", suffix_);
@@ -279,26 +409,40 @@ class PartitionedCallOp : public AsyncOpKernel {
   }
 
   NameAttrList func_;
-  const string local_device_name_;
+  string local_device_name_;
   // Function shards are added to `overlay_lib_`.
   std::unique_ptr<FunctionLibraryDefinition> overlay_lib_;
-  // A map from device names to handles of function shards; this map is
-  // read-only after the first execution of the OpKernel.
-  gtl::FlatMap<string, FHandle> function_handles_;
+  // Contains maps from device names to handles of function shards, keyed by
+  // FunctionLibraryRuntime pointers. (Because this kernel may be instantiated
+  // for a stateful op, different invocations of it may use different FLRs.)
+  gtl::FlatMap<FunctionLibraryRuntime*,
+               std::unique_ptr<gtl::FlatMap<string, FHandle>>>
+      function_handles_ GUARDED_BY(mu_);
+  // Map from device name to the indices of the arguments and return values
+  // placed on that device. Read-only after the first invocation.
+  gtl::FlatMap<string, ArgAndRetIndices> arg_and_ret_indices_;
+  // Map from device name to alloc attrs for arguments and return values of
+  // the function placed on that device. Read-only after the first invocation.
+  gtl::FlatMap<string, ArgAndRetAllocAttrs> arg_and_ret_alloc_attrs_;
 
   mutex mu_;
-  bool partitioned_ GUARDED_BY(mu_) = false;
 
   // Used to uniquify function names in `overlay_lib_`.
   uint32 suffix_ = 0;
 };
 
 REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
                         PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_CPU),
+                        PartitionedCallOp);
 REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_GPU),
                         PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_GPU),
+                        PartitionedCallOp);
 #if TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_SYCL),
                         PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL),
+                        PartitionedCallOp);
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace
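
Editor's notes on the three patterns this change introduces follow.

First, the commit replaces the single `partitioned_` flag with a map keyed by `FunctionLibraryRuntime*`: a stateful kernel instance can be shared by subgraphs that carry different runtimes, and each runtime has its own `FHandle` namespace. Below is a minimal, self-contained sketch of that caching pattern in plain C++; `Runtime`, `Handle`, and `ShardCache` are illustrative stand-ins, not TensorFlow APIs.

#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

using Handle = int;
struct Runtime {};  // stand-in for FunctionLibraryRuntime

class ShardCache {
 public:
  // Returns the per-device handle map for `rt`, building it exactly once per
  // runtime (in the real kernel, this is where Instantiate/Partition runs).
  const std::unordered_map<std::string, Handle>& HandlesFor(Runtime* rt) {
    std::lock_guard<std::mutex> l(mu_);
    auto it = handles_.find(rt);
    if (it == handles_.end()) {
      auto per_device =
          std::make_unique<std::unordered_map<std::string, Handle>>();
      (*per_device)["/device:CPU:0"] = next_handle_++;
      (*per_device)["/device:GPU:0"] = next_handle_++;
      it = handles_.emplace(rt, std::move(per_device)).first;
    }
    return *it->second;
  }

 private:
  std::mutex mu_;
  // Analogous to function_handles_ above: handles are namespaced per runtime.
  std::unordered_map<Runtime*,
                     std::unique_ptr<std::unordered_map<std::string, Handle>>>
      handles_;
  Handle next_handle_ = 0;
};

int main() {
  ShardCache cache;
  Runtime rt1, rt2;
  // Distinct runtimes get independent handle maps, built lazily.
  std::cout << cache.HandlesFor(&rt1).at("/device:CPU:0") << "\n";  // 0
  std::cout << cache.HandlesFor(&rt2).at("/device:CPU:0") << "\n";  // 2
}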
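Second, `UpdateArgAndRetMetadata` records, per device, the global indices of the `_Arg`/`_Retval` nodes that partitioning assigned to that shard, while rewriting each node's local `index` attribute to 0..n-1. At run time, `GetArgsForIndices` gathers the shard's inputs from the full argument list, and `ctx->set_output(ret_indices[i], ...)` scatters its outputs back to the global return slots. A minimal stand-alone sketch of that gather/scatter, with `Tensor` replaced by a string stand-in and the "function run" faked by concatenation:

#include <iostream>
#include <string>
#include <vector>

using Tensor = std::string;  // stand-in

// Gather: select a shard's inputs by their global argument indices.
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
                                      const std::vector<Tensor>& all_args) {
  std::vector<Tensor> args;
  args.reserve(indices.size());
  for (int i : indices) args.push_back(all_args[i]);
  return args;
}

int main() {
  std::vector<Tensor> op_args = {"a0", "a1", "a2"};
  std::vector<Tensor> outputs(2);

  // Suppose partitioning gave this shard global args {0, 2} and global
  // return slot {1}; locally, the shard sees them renumbered 0..n-1.
  std::vector<int> arg_indices = {0, 2};
  std::vector<int> ret_indices = {1};

  std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
  std::vector<Tensor> rets = {args[0] + "+" + args[1]};  // fake shard run

  // Scatter: place each local return value in its global output slot.
  for (size_t i = 0; i < rets.size(); ++i) outputs[ret_indices[i]] = rets[i];

  std::cout << outputs[1] << "\n";  // prints "a0+a2"
}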
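Finally, the kernel joins the per-shard `lib->Run` completions through `ReffedStatusCallback`: the shared callback starts with one reference, gains one `Ref` per extra shard, and whichever shard finishes last triggers the done callback via its `Unref`. A minimal sketch of that refcount-join using only the standard library; `RefcountedDone` is an illustrative name, not TensorFlow's class.

#include <atomic>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

class RefcountedDone {
 public:
  explicit RefcountedDone(std::function<void()> done)
      : done_(std::move(done)) {}
  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    // The final Unref runs the callback exactly once, then self-destructs.
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      done_();
      delete this;
    }
  }

 private:
  ~RefcountedDone() = default;  // heap-only, deleted by the last Unref
  std::atomic<int> refs_{1};    // constructed holding one reference
  std::function<void()> done_;
};

int main() {
  constexpr int kShards = 4;
  auto* joined = new RefcountedDone([] { std::cout << "all shards done\n"; });
  // Mirrors the kernel's loop: one extra Ref per shard beyond the first.
  for (int i = 1; i < kShards; ++i) joined->Ref();

  std::vector<std::thread> shards;
  for (int i = 0; i < kShards; ++i) {
    shards.emplace_back([joined] {
      // ... run one function shard asynchronously ...
      joined->Unref();  // the last finisher fires the done callback
    });
  }
  for (auto& t : shards) t.join();
}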