author A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-09 10:56:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-09 10:59:33 -0700
commit 6594b9f530ee0a82b61a4b0d2b80c3ced1464fb7 (patch)
tree 70930b4c03aee76e57759148c9dcc362c687e55c /tensorflow/core/kernels/collective_ops.cc
parent 7576a99c49679dc17ff806acb1a5150f5d16ee58 (diff)
Collective Ops Part 2
Kernel/Op defs for reduction and broadcast. Note that the kernels only set up CollectiveParams; they do not define the detailed algorithms. This change is part of a series of changes introducing infrastructure for collective ops and initial implementations of reduction and broadcast.

PiperOrigin-RevId: 192151715
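
As a sketch of how these ops are parameterized, a NodeDef for CollectiveReduce might be built as follows; the node name, input name, and key values are illustrative only, while the attribute names mirror the GetAttr calls in the kernels below:

    // Hypothetical sketch (inside namespace tensorflow); node/input names
    // and key values are made up for illustration.
    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/node_def_builder.h"

    NodeDef reduce_node;
    Status status =
        NodeDefBuilder("all_reduce_0", "CollectiveReduce")
            .Input("input_tensor", 0, DT_FLOAT)
            .Attr("group_size", 2)     // number of devices in the group
            .Attr("group_key", 1)      // identifies the device group
            .Attr("instance_key", 17)  // identifies this collective instance
            .Attr("merge_op", "Add")   // must be "Add" or "Mul"
            .Attr("final_op", "Div")   // must be "Id" or "Div"
            .Attr("subdiv_offsets", std::vector<int32>({0}))
            .Attr("T", DT_FLOAT)
            .Finalize(&reduce_node);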
Diffstat (limited to 'tensorflow/core/kernels/collective_ops.cc')
-rw-r--r-- tensorflow/core/kernels/collective_ops.cc | 266
1 file changed, 266 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
new file mode 100644
index 0000000000..5de41bac72
--- /dev/null
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -0,0 +1,266 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
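+// Base class for the collective kernels defined below: holds the op's
+// CollectiveParams and defers their completion to the CollectiveExecutor
+// on first invocation (see CanProceedWithCompute).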
+class CollectiveOpKernel : public AsyncOpKernel {
+ public:
+ explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+ // A string encoding instance, frame and iter to be handed off to
+ // the implementation for use in generating RecvBuf keys.
+ string GetCollectiveKey(OpKernelContext* c) {
+ return strings::StrCat(col_params_.instance.instance_key, ":",
+ c->frame_iter().frame_id, ":",
+ c->frame_iter().iter_id);
+ }
+
+  // Returns false if the calling invocation of ComputeAsync should return
+  // immediately.
+ bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
+ const DoneCallback& done) {
+ if (col_params_.group.group_size >
+ col_params_.instance.device_names.size()) {
+      // This is the first invocation: finish initializing col_params_.
+      // Schedule the closure on a thread that may block, since
+      // CompleteParamsAsync is not guaranteed to be non-blocking.
+ c->env()->SchedClosure([this, c, done, col_exec]() {
+ col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
+ c->cancellation_manager(),
+ [this, c, done](const Status& s) {
+ if (s.ok()) {
+ ComputeAsync(c, done);
+ } else {
+ c->SetStatus(s);
+ done();
+ }
+ });
+ });
+ return false;
+ }
+ return true;
+ }
+
+ CollectiveParams col_params_;
+};
+
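+// CollectiveReduce: parses the reduction attributes and pre-builds the
+// merge_op and final_op sub-kernels that the collective implementation
+// applies element-wise to the data.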
+class CollectiveReduceOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = REDUCTION_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("subdiv_offsets",
+ &col_params_.instance.impl_details.subdiv_offsets));
+ string merge_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
+ OP_REQUIRES(c, merge_op_name == "Add" || merge_op_name == "Mul",
+ errors::InvalidArgument(
+ "merge_op must be one of {\"Add\", \"Mul\"} but got ",
+ merge_op_name));
+ string final_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
+ OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
+ errors::InvalidArgument(
+ "final_op must be one of {\"Id\", \"Div\"} but got ",
+ final_op_name));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+
+ const NodeDef& real_node = c->def();
+ col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
+ merge_op_name, ",", final_op_name, ")");
+ col_params_.group.device_type = c->device_type();
+
+ // Find the OpKernels by name, type and device type.
+ NodeDef sub_node;
+    // The merge_op takes two inputs; supply input 0 twice to satisfy its
+    // expected signature.
+ sub_node.add_input(real_node.input(0));
+ sub_node.add_input(real_node.input(0));
+ sub_node.set_device(real_node.device());
+ SetAttrValue(col_params_.instance.data_type,
+ &(*sub_node.mutable_attr())["T"]);
+ col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
+ col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+ }
+
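+  // Builds an OpKernel for the named elementwise op (e.g. "Add", "Mul",
+  // "Div") on this kernel's device.  Returns nullptr for "Id" or an empty
+  // name, meaning no sub-kernel is required.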
+ std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
+ const string& name,
+ NodeDef* sub_node) {
+ std::unique_ptr<OpKernel> k;
+ if (name.empty() || name == "Id") return k;
+ sub_node->set_name(name);
+ sub_node->set_op(name);
+ Status status;
+ k = CreateOpKernel(c->device_type(), c->device(),
+ c->device()->GetAllocator(AllocatorAttributes()),
+ *sub_node, c->graph_def_version(), &status);
+ if (!status.ok()) {
+ c->CtxFailureWithWarning(errors::Internal("Failed to build OpKernel for ",
+ name, " : ",
+ status.error_message()));
+ }
+ return k;
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+
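+    // Wrap the done callback so that any error from the collective
+    // implementation is recorded on the OpKernelContext before completion.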
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
+ CollectiveReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
+ CollectiveReduceOpKernel);
+
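+// CollectiveBcastSend: the source rank of a broadcast; its input tensor is
+// forwarded to the output and sent to the other members of the group.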
+class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = true;
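+    // Broadcast uses a single subdivision.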
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ OP_REQUIRES_ASYNC(
+ c, shape_.IsSameSize(c->input(0).shape()),
+ errors::Internal("Declared shape of op ", col_params_.name,
+ " does not match shape of input"),
+ done);
+    // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
+ CollectiveBcastSendOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
+ CollectiveBcastSendOpKernel);
+
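+// CollectiveBcastRecv: a non-source rank of a broadcast; it has no input
+// and receives the broadcast tensor into a freshly allocated output.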
+class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = false;
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
+ CollectiveBcastRecvOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
+ CollectiveBcastRecvOpKernel);
+
+} // namespace
+} // namespace tensorflow