author     2018-04-09 10:56:29 -0700
committer  2018-04-09 10:59:33 -0700
commit     6594b9f530ee0a82b61a4b0d2b80c3ced1464fb7 (patch)
tree       70930b4c03aee76e57759148c9dcc362c687e55c /tensorflow/core/kernels/collective_ops.cc
parent     7576a99c49679dc17ff806acb1a5150f5d16ee58 (diff)
Collective Ops Part 2
Kernel/Op defs for reduction and broadcast.
Note that the kernels only set up CollectiveParams; they do not
define the detailed algorithms.
This change is part of a series of changes introducing infrastructure
for collective ops and initial implementations of reduction and broadcast.
PiperOrigin-RevId: 192151715
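
(To make the split described above concrete: the kernel constructors below only copy op attributes into a CollectiveParams descriptor, and the actual reduction/broadcast algorithm runs later in the CollectiveExecutor. What follows is a minimal standalone C++ sketch of that pattern. The structs are simplified stand-ins whose field names mirror the real types in tensorflow/core/framework/collective.h; this is not the actual TF API.)

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-ins for TF's CollGroupParams / CollInstanceParams /
// CollectiveParams (see the diff below for the real field names).
struct CollGroupParams {
  int32_t group_size = 0;
  int32_t group_key = 0;
  std::string device_type;
};

struct CollInstanceParams {
  int32_t instance_key = 0;
  std::vector<int32_t> subdiv_offsets;
};

struct CollectiveParams {
  CollGroupParams group;
  CollInstanceParams instance;
  std::string name;
};

// What CollectiveReduceOpKernel's constructor does in spirit: record the
// op's attrs in the params block and stop. No reduction algorithm is
// chosen here; that is the CollectiveExecutor's job at execution time.
CollectiveParams MakeReduceParams(int32_t group_size, int32_t group_key,
                                  int32_t instance_key) {
  CollectiveParams p;
  p.group.group_size = group_size;
  p.group.group_key = group_key;
  p.group.device_type = "CPU";
  p.instance.instance_key = instance_key;
  p.instance.subdiv_offsets = {0};
  p.name = "reduce_0: Reduce(Add,Div)";  // naming mirrors the diff below
  return p;
}

int main() {
  const CollectiveParams p = MakeReduceParams(4, 1, 7);
  std::cout << p.name << ", group_size=" << p.group.group_size << "\n";
  return 0;
}
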
Diffstat (limited to 'tensorflow/core/kernels/collective_ops.cc')
-rw-r--r--  tensorflow/core/kernels/collective_ops.cc  266
1 file changed, 266 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
new file mode 100644
index 0000000000..5de41bac72
--- /dev/null
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -0,0 +1,266 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+class CollectiveOpKernel : public AsyncOpKernel {
+ public:
+  explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+  // A string encoding instance, frame and iter to be handed off to
+  // the implementation for use in generating RecvBuf keys.
+  string GetCollectiveKey(OpKernelContext* c) {
+    return strings::StrCat(col_params_.instance.instance_key, ":",
+                           c->frame_iter().frame_id, ":",
+                           c->frame_iter().iter_id);
+  }
+
+  // Returns false if the calling invocation of ComputeAsync should return
+  // immediately.
+  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
+                             const DoneCallback& done) {
+    if (col_params_.group.group_size >
+        col_params_.instance.device_names.size()) {
+      // This is the first invocation: Finish initializing col_params_.
+      // Make the call on a blockable thread because CompleteParamsAsync
+      // is not guaranteed to return without blocking.
+      c->env()->SchedClosure([this, c, done, col_exec]() {
+        col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
+                                      c->cancellation_manager(),
+                                      [this, c, done](const Status& s) {
+                                        if (s.ok()) {
+                                          ComputeAsync(c, done);
+                                        } else {
+                                          c->SetStatus(s);
+                                          done();
+                                        }
+                                      });
+      });
+      return false;
+    }
+    return true;
+  }
+
+  CollectiveParams col_params_;
+};
+
+class CollectiveReduceOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = REDUCTION_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("subdiv_offsets",
+                      &col_params_.instance.impl_details.subdiv_offsets));
+    string merge_op_name;
+    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
+    OP_REQUIRES(c, merge_op_name == "Add" || merge_op_name == "Mul",
+                errors::InvalidArgument(
+                    "merge_op must be one of {\"Add\", \"Mul\"} but got ",
+                    merge_op_name));
+    string final_op_name;
+    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
+    OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
+                errors::InvalidArgument(
+                    "final_op must be one of {\"Id\", \"Div\"} but got ",
+                    final_op_name));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+
+    const NodeDef& real_node = c->def();
+    col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
+                                       merge_op_name, ",", final_op_name, ")");
+    col_params_.group.device_type = c->device_type();
+
+    // Find the OpKernels by name, type and device type.
+    NodeDef sub_node;
+    // The merge_op takes two inputs.
+    sub_node.add_input(real_node.input(0));
+    sub_node.add_input(real_node.input(0));
+    sub_node.set_device(real_node.device());
+    SetAttrValue(col_params_.instance.data_type,
+                 &(*sub_node.mutable_attr())["T"]);
+    col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
+    col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+  }
+
+  std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
+                                          const string& name,
+                                          NodeDef* sub_node) {
+    std::unique_ptr<OpKernel> k;
+    if (name.empty() || name == "Id") return k;
+    sub_node->set_name(name);
+    sub_node->set_op(name);
+    Status status;
+    k = CreateOpKernel(c->device_type(), c->device(),
+                       c->device()->GetAllocator(AllocatorAttributes()),
+                       *sub_node, c->graph_def_version(), &status);
+    if (!status.ok()) {
+      c->CtxFailureWithWarning(errors::Internal("Failed to build OpKernel for ",
+                                                name, " : ",
+                                                status.error_message()));
+    }
+    return k;
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    // Allocate the output tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c,
+                         c->forward_input_or_allocate_output(
+                             {0}, 0, c->input(0).shape(), &output),
+                         done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
+                        CollectiveReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
+                        CollectiveReduceOpKernel);
+
+class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = BROADCAST_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+    OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+    col_params_.is_source = true;
+    col_params_.instance.impl_details.subdiv_offsets = {0};
+
+    col_params_.name =
+        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+    col_params_.group.device_type = c->device_type();
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    OP_REQUIRES_ASYNC(
+        c, shape_.IsSameSize(c->input(0).shape()),
+        errors::Internal("Declared shape of op ", col_params_.name,
+                         " does not match shape of input"),
+        done);
+    // Allocate the output tensor, trying to reuse the input.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TensorShape shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
+                        CollectiveBcastSendOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
+                        CollectiveBcastSendOpKernel);
+
+class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
+ public:
+  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
+      : CollectiveOpKernel(c) {
+    col_params_.instance.type = BROADCAST_COLLECTIVE;
+    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+    OP_REQUIRES_OK(
+        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+    OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+    col_params_.is_source = false;
+    col_params_.instance.impl_details.subdiv_offsets = {0};
+
+    col_params_.name =
+        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+    col_params_.group.device_type = c->device_type();
+  }
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            col_params_.name),
+        done);
+    if (!CanProceedWithCompute(c, col_exec, done)) return;
+    // No input, so must allocate output.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+
+    auto actual_done = [c, col_exec, done](const Status& s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+  }
+
+ private:
+  TensorShape shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
+                        CollectiveBcastRecvOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
+                        CollectiveBcastRecvOpKernel);
+
+}  // namespace
+}  // namespace tensorflow
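
(For reference, the RecvBuf key built by GetCollectiveKey() above is simply "<instance_key>:<frame_id>:<iter_id>". A standalone approximation follows; std::to_string stands in for strings::StrCat, and the frame/iter values, which TF would read from OpKernelContext::frame_iter(), are hypothetical here.)

#include <cstdint>
#include <iostream>
#include <string>

// Approximates CollectiveOpKernel::GetCollectiveKey from the diff above.
std::string GetCollectiveKey(int32_t instance_key, int64_t frame_id,
                             int64_t iter_id) {
  return std::to_string(instance_key) + ":" + std::to_string(frame_id) + ":" +
         std::to_string(iter_id);
}

int main() {
  // Distinct frames/iterations of the same collective instance yield
  // distinct keys, so their buffer exchanges cannot collide.
  std::cout << GetCollectiveKey(7, 0, 0) << "\n";  // prints 7:0:0
  std::cout << GetCollectiveKey(7, 0, 1) << "\n";  // prints 7:0:1
  return 0;
}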