author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-15 21:36:48 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-15 21:41:29 -0700
commit     ca8af1d0dbb605087a4f8ae076188f2b9a26b1ba (patch)
tree       fc2b16452cfaba938365c21ba9221ca03953c28a
parent     e30246c49b353b9136f69caef23e7ba0e9df0f0e (diff)
Replace NcclReduce/Broadcast ops during graph optimization so that we can generate gradients for Reduce/Broadcast.

Change the _NcclBroadcastRecv shape input to int32 so that the corresponding Const op outputs to HostMem.

PiperOrigin-RevId: 172279684
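With this change, the user-facing Python API collapses the explicit send/recv pairs into single collective ops (NcclReduce, NcclBroadcast) that support tf.gradients. A minimal sketch of the resulting usage, assuming two visible GPUs and TF 1.x graph mode:

import tensorflow as tf
from tensorflow.contrib import nccl

devices = ['/device:GPU:0', '/device:GPU:1']
inputs = []
for d in devices:
  with tf.device(d):
    inputs.append(tf.random_normal([3, 4]))

# reduce_sum no longer takes a dst_device argument; the result is placed on the
# enclosing device scope, which must match the device of one of the inputs.
with tf.device(devices[0]):
  reduced = nccl.reduce_sum(inputs)

# broadcast now returns a single tensor instead of a (send_op, recv_tensors)
# pair; consumers on other devices reference it through ordinary data edges.
with tf.device(devices[0]):
  shared = nccl.broadcast(reduced)
with tf.device(devices[1]):
  remote_copy = tf.identity(shared)

# NcclReduce/NcclBroadcast now have registered gradient functions, so
# tf.gradients can differentiate through them. The upstream gradient is placed
# on the same device as the reduce output, mirroring the updated tests.
with tf.device(reduced.device):
  grad_ys = tf.ones_like(reduced)
grads = tf.gradients([reduced], inputs, [grad_ys])

The GPU kernels registered for NcclReduce and NcclBroadcast are stubs; the new NcclReplacePass in nccl_rewrite.cc (below) rewrites them into _NcclReduceSend/_NcclReduceRecv and _NcclBroadcastSend/_NcclBroadcastRecv after placement.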
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r--  tensorflow/contrib/nccl/BUILD                        |   2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_ops.cc          |  28
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_rewrite.cc      | 271
-rw-r--r--  tensorflow/contrib/nccl/ops/nccl_ops.cc              |  84
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops.py       | 138
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops_test.py  |  87
6 files changed, 483 insertions, 127 deletions
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index d6508362b8..5e7263ff62 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -71,10 +71,12 @@ tf_kernel_library(
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
+ "kernels/nccl_rewrite.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:proto_text",
"@nccl_archive//:nccl",
],
alwayslink = 1,
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
index 4eb52492db..266d4f6f0d 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -15,8 +15,6 @@ limitations under the License.
#if GOOGLE_CUDA
-#include <memory>
-#include <unordered_map>
#include <vector>
#include "src/nccl.h"
@@ -24,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace {
// Base class for all communicator ops that use nccl.
//
@@ -134,7 +133,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase {
compute_stream, &c->input(0), std::move(actual_done));
}
};
-REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
+REGISTER_KERNEL_BUILDER(Name("_NcclReduceSend").Device(DEVICE_GPU),
NcclReduceSendKernel);
// To execute a single reduce, this kernel is called once for one devices, and
@@ -166,7 +165,7 @@ class NcclReduceRecvKernel : public NcclReduceOpBase {
private:
ncclRedOp_t reduction_op_;
};
-REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
+REGISTER_KERNEL_BUILDER(Name("_NcclReduceRecv").Device(DEVICE_GPU),
NcclReduceRecvKernel);
// To execute a single broadcast, this kernel is called once for one device, and
@@ -191,7 +190,7 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
std::move(actual_done));
}
};
-REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
+REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU),
NcclBroadcastSendKernel);
// To execute a single broadcast, this kernel is called once for all but one of
@@ -206,7 +205,7 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
const Tensor& shape_t = c->input(0);
TensorShape shape;
OP_REQUIRES_OK_ASYNC(
- c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done);
+ c, TensorShapeUtils::MakeShape(shape_t.vec<int32>(), &shape), done);
Tensor* out_t;
OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done);
@@ -224,9 +223,24 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
}
};
REGISTER_KERNEL_BUILDER(
- Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"),
+ Name("_NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"),
NcclBroadcastRecvKernel);
+// Define stub kernels for the ops that get replaced post placement.
+class NcclStubKernel : public AsyncOpKernel {
+ public:
+ explicit NcclStubKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ c->SetStatus(errors::Unimplemented(
+ "This op should be replaced during graph optimization."));
+ done();
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclBroadcast").Device(DEVICE_GPU),
+ NcclStubKernel);
+REGISTER_KERNEL_BUILDER(Name("NcclReduce").Device(DEVICE_GPU), NcclStubKernel);
+
+} // namespace
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
new file mode 100644
index 0000000000..94a77c59da
--- /dev/null
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -0,0 +1,271 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include <forward_list>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace {
+
+// Replaces NcclReduce node with _NcclReduceRecv reusing one input of same
+// device, adds one _NcclReduceSend for each other input.
+Status ReplaceReduce(Graph* graph, Node* node) {
+ string reduction;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction));
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
+ int num_devices = node->num_inputs();
+ string shared_name = node->name();
+ auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
+ return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
+ .Attr("reduction", reduction)
+ .Attr("num_devices", num_devices)
+ .Attr("shared_name", shared_name)
+ .Attr("T", dtype);
+ };
+ std::vector<Node*> control_inputs;
+ for (const auto& edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ control_inputs.push_back(edge->src());
+ }
+ }
+ std::vector<NodeBuilder::NodeOut> out_nodes;
+ for (const auto& edge : node->out_edges()) {
+ out_nodes.emplace_back(edge->dst(), edge->dst_input());
+ }
+ int recv_dev = node->assigned_device_name_index();
+ NodeBuilder recv_builder =
+ make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs);
+ bool recv_input_set = false;
+ int send_counter = 0;
+ for (const auto& edge : node->in_edges()) {
+ Node* src_node = edge->src();
+ if (edge->IsControlEdge()) {
+ continue;
+ }
+ int send_dev = src_node->assigned_device_name_index();
+ if (!recv_input_set && send_dev == recv_dev) {
+ recv_builder.Input(src_node);
+ recv_input_set = true;
+ continue;
+ }
+ auto send_builder = make_builder("_NcclReduceSend",
+ strings::StrCat("Send_", ++send_counter))
+ .Input(src_node)
+ .ControlInputs(control_inputs);
+ Node* send_node = nullptr;
+ TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
+ send_node->set_assigned_device_name_index(send_dev);
+ // Send nodes don't have any outputs and therefore have no data dependencies
+ // to the outputs of the graph. We add a control dependency to the receive
+ // node so that those 'dangling' nodes are run.
+ // TODO(b/67027412): Avoid these cross-device control edges.
+ for (const auto& out_node : out_nodes) {
+ graph->AddControlEdge(send_node, out_node.node);
+ }
+ }
+ if (!recv_input_set) {
+ return errors::InvalidArgument(
+ "No input tensor uses the same device as the NcclReduce op");
+ }
+ Node* recv_node = nullptr;
+ TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node));
+ recv_node->set_assigned_device_name_index(recv_dev);
+ graph->RemoveNode(node);
+ for (const auto& out_node : out_nodes) {
+ if (out_node.index == Graph::kControlSlot) {
+ graph->AddControlEdge(recv_node, out_node.node);
+ } else {
+ graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
+ }
+ }
+ return Status::OK();
+}
+
+TensorProto TensorFromShape(const TensorShapeProto& shape) {
+ TensorProto result;
+ result.set_dtype(DT_INT32);
+ for (const auto& dim : shape.dim()) {
+ result.add_int_val(dim.size());
+ }
+ result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size());
+ return result;
+}
+
+// Replaces NcclBroadcast node with _NcclBroadcastSend, connects the input to
+// all outputs of same device, adds one _NcclBroadcastRecv for each other output
+// device.
+Status ReplaceBroadcast(Graph* graph, Node* node) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
+ int send_dev = node->assigned_device_name_index();
+ int num_devices = 0; // Number of distinct devices, incremented below.
+
+ // Map device name index to nodes that take the broadcast as input.
+ std::vector<std::forward_list<NodeBuilder::NodeOut>> out_nodes_map;
+ for (const auto& edge : node->out_edges()) {
+ int dst_dev = edge->IsControlEdge()
+ ? send_dev
+ : edge->dst()->assigned_device_name_index();
+ if (out_nodes_map.size() <= dst_dev) {
+ out_nodes_map.resize(dst_dev + 1);
+ }
+ auto it = out_nodes_map.begin() + dst_dev;
+ if (it->empty()) {
+ ++num_devices;
+ }
+ it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input()));
+ }
+
+ if (num_devices <= 1) {
+ // Only one participating device, skip NCCL op.
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge));
+ Node* in_node = in_edge->src();
+ int in_index = in_edge->src_output();
+ graph->RemoveNode(node);
+ for (const auto& out_nodes : out_nodes_map) {
+ for (const auto& out_node : out_nodes) {
+ if (out_node.index == Graph::kControlSlot) {
+ graph->AddControlEdge(in_node, out_node.node);
+ } else {
+ graph->AddEdge(in_node, in_index, out_node.node, out_node.index);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ string shared_name = node->name();
+ auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
+ return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
+ .Attr("num_devices", num_devices)
+ .Attr("shared_name", shared_name)
+ .Attr("T", dtype);
+ };
+
+ // Create broadcast send node and replace the original broadcast node.
+ NodeBuilder::NodeOut in_node;
+ NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send");
+ for (const auto& edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ send_builder.ControlInput(edge->src());
+ } else {
+ in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output());
+ send_builder.Input(in_node);
+ }
+ }
+ Node* send_node = nullptr;
+ TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
+ send_node->set_assigned_device_name_index(send_dev);
+
+ TensorShapeProto shape_proto;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto));
+
+ // Delete the original node before reconnecting to outputs.
+ graph->RemoveNode(node);
+
+ // Connect all outputs on the device of broadcast send.
+ for (const auto& out_node : out_nodes_map[send_dev]) {
+ if (out_node.index == Graph::kControlSlot) {
+ graph->AddControlEdge(send_node, out_node.node);
+ } else {
+ graph->AddEdge(in_node.node, in_node.index, out_node.node,
+ out_node.index);
+ // Add control edge so send node is run.
+ graph->AddControlEdge(send_node, out_node.node);
+ }
+ }
+ out_nodes_map[send_dev].clear();
+
+ TensorProto tensor_proto = TensorFromShape(shape_proto);
+ bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined();
+ string shape_name = strings::StrCat(in_node.node->name(), "/Shape");
+ Node* shape_node = nullptr;
+ if (!is_fully_defined) {
+ NodeBuilder shape_builder(shape_name, "Shape");
+ shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype);
+ TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
+ shape_node->set_assigned_device_name_index(send_dev);
+ }
+
+ // For all other devices, create a broadcast receive and connect outputs.
+ for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) {
+ if (out_nodes_map[recv_dev].empty()) {
+ continue;
+ }
+ if (is_fully_defined) {
+ // If the shape is fully defined, define one const node per device.
+ NodeBuilder shape_builder(strings::StrCat(shape_name, recv_dev), "Const");
+ shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32);
+ TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
+ shape_node->set_assigned_device_name_index(recv_dev);
+ }
+ Node* recv_node;
+ TF_RETURN_IF_ERROR(
+ make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_dev))
+ .Input(shape_node)
+ .Finalize(graph, &recv_node));
+ recv_node->set_assigned_device_name_index(recv_dev);
+ for (const auto& out_node : out_nodes_map[recv_dev]) {
+ graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
+ }
+ }
+
+ return Status::OK();
+}
+
+// Replaces occurrences of Nccl{Reduce, Broadcast}Input/Output with their
+// _Nccl...Send/Recv counterparts and removes data dependencies between them.
+class NcclReplacePass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override {
+ if (options.graph == nullptr) {
+ return Status::OK();
+ }
+ Graph* graph = options.graph->get();
+ if (graph == nullptr) {
+ return errors::Internal(
+ "NCCL replacement should happen before partitioning and a "
+ "graph should be available.");
+ }
+ // Find reduction and broadcast ops and replace them with Send/Recv ops.
+ for (Node* node : graph->op_nodes()) {
+ StringPiece type = node->type_string();
+ if (!type.starts_with("Nccl")) {
+ continue;
+ }
+ if (type == "NcclReduce") {
+ TF_RETURN_IF_ERROR(ReplaceReduce(graph, node));
+ }
+ if (type == "NcclBroadcast") {
+ TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node));
+ }
+ }
+ return Status::OK();
+ }
+};
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
+ NcclReplacePass);
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
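Because NcclReplacePass is registered at the POST_PLACEMENT stage, the replacement happens before the graph is partitioned across devices. A minimal sketch (for illustration, assuming at least two GPUs and TF 1.x graph mode) of one way to observe the rewritten nodes through the partition graphs:

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/device:GPU:0'):
  src = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  shared = nccl.broadcast(src)
with tf.device('/device:GPU:1'):
  out = tf.identity(shared)

# Request the partition graphs, which are produced after the POST_PLACEMENT
# pass above has already replaced the NcclBroadcast stub.
run_options = tf.RunOptions(output_partition_graphs=True)
run_metadata = tf.RunMetadata()
with tf.Session() as sess:
  sess.run(out, options=run_options, run_metadata=run_metadata)

# The partitioned graphs should now contain _NcclBroadcastSend and
# _NcclBroadcastRecv nodes in place of the NcclBroadcast op created at
# graph-construction time.
for graph_def in run_metadata.partition_graphs:
  print([node.name for node in graph_def.node if node.op.startswith('_Nccl')])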
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
index 532c79c24c..8eb804c2e9 100644
--- a/tensorflow/contrib/nccl/ops/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -45,7 +45,28 @@ num_devices: The number of devices participating in this reduction.
shared_name: Identifier that shared between ops of the same reduction.
)doc");
-REGISTER_OP("NcclReduceSend")
+// Note: This op has no kernel implementation, but is replaced by
+// _NcclReduceSend and _NcclReduceRecv during graph optimization stage.
+REGISTER_OP("NcclReduce")
+ .Input("input: num_devices * T")
+ .Output("data: T")
+ .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Reduces `input` from `num_devices` using `reduction` to a single device.
+
+The graph should be constructed so that all inputs have a valid device
+assignment, and the op itself is assigned one of these devices.
+
+input: The input to the reduction.
+data: the value of the reduction across all `num_devices` devices.
+reduction: the reduction operation to perform.
+ )doc");
+
+REGISTER_OP("_NcclReduceSend")
.Input("input: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
.Attr("T: {float, float64, int32, int64}")
@@ -54,19 +75,20 @@ REGISTER_OP("NcclReduceSend")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
-Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
+Replacement node for NcclReduce.
+Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
The graph should be constructed so that 'num_devices-1' devices run
-`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.
-input: The input to the reduction
+input: The input to the reduction.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
-REGISTER_OP("NcclReduceRecv")
+REGISTER_OP("_NcclReduceRecv")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
@@ -76,21 +98,42 @@ REGISTER_OP("NcclReduceRecv")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
+Replacement node for NcclReduce.
+
Reduces 'input' from this op and the NcclReduceSend ops registered in the same
`shared_name`.
-
The graph should be constructed so that 'num_devices-1' devices run
-`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.
-input: The input to the reduction
+input: The input to the reduction.
data: The reduced data received from this op and the NcclReduceSend op.
reduction: the reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
-REGISTER_OP("NcclBroadcastSend")
+// Note: This op has no kernel implementation, but is replaced by
+// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage.
+REGISTER_OP("NcclBroadcast")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("shape: shape")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Sends `input` to all devices that are connected to the output.
+
+The graph should be constructed so that all ops connected to the output have a
+valid device assignment, and the op itself is assigned one of these devices.
+
+input: The input to the broadcast.
+output: The same as input.
+shape: The shape of the input tensor.
+ )doc");
+
+REGISTER_OP("_NcclBroadcastSend")
.Input("input: T")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
@@ -98,19 +141,21 @@ REGISTER_OP("NcclBroadcastSend")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
-Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`.
+Replacement node for NcclBroadcast.
-The graph should be constructed so that one device runs `NcclBroadcastSend` and
-`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
+Sends `input` to the _NcclBroadcastRecv ops registered in the same
+`shared_name`.
+The graph should be constructed so that one device runs `_NcclBroadcastSend` and
+`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.
-input: The input to the broadcast
+input: The input to the broadcast.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
-REGISTER_OP("NcclBroadcastRecv")
- .Input("shape: int64")
+REGISTER_OP("_NcclBroadcastRecv")
+ .Input("shape: int32")
.Output("output: T")
.Attr("T: {float, float64, int32, int64}")
.Attr("num_devices: int")
@@ -123,11 +168,12 @@ REGISTER_OP("NcclBroadcastRecv")
return Status::OK();
})
.Doc(R"doc(
-Sends data of shape `shape` from the NcclBroadcastSend op registered in the
-same `shared_name`.
+Replacement node for NcclBroadcast.
-The graph should be constructed so that one device runs `NcclBroadcastSend` and
-`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
+Sends data of shape `shape` from the _NcclBroadcastSend op registered in the
+same `shared_name`.
+The graph should be constructed so that one device runs `_NcclBroadcastSend` and
+`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`.
Failure to do so will cause the graph execution to fail to complete.
shape: The shape of the output.
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
index 906d9f948a..8dc038b9ac 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
@@ -23,9 +23,7 @@ from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
from tensorflow.python.eager import context
from tensorflow.python.framework import device
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
_nccl_ops_so = loader.load_op_library(
@@ -64,13 +62,13 @@ def _all_sum_grad(op, grad):
LookupError: If `reduction` is not `sum`.
"""
if op.get_attr('reduction') != 'sum':
- raise LookupError('No gradient defined for NcclAllReduce except all_sum.')
+ raise LookupError('No gradient defined for NcclAllReduce except sum.')
- _check_device_assignment(grad)
+ _check_device(grad, expected=op.device)
num_devices = op.get_attr('num_devices')
shared_name = op.get_attr('shared_name') + '_grad'
- with ops.device(grad.device):
+ with ops.device(op.device):
return gen_nccl_ops.nccl_all_reduce(
input=grad,
reduction='sum',
@@ -129,7 +127,7 @@ def all_max(tensors):
return _apply_all_reduce('max', tensors)
-def reduce_sum(tensors, dst_device):
+def reduce_sum(tensors):
"""Returns a tensor with the reduce sum across `tensors`.
The computation is done with a reduce operation, so only one tensor is
@@ -138,54 +136,76 @@ def reduce_sum(tensors, dst_device):
Args:
tensors: The input tensors across which to sum; must be assigned
to GPU devices.
- dst_device: The device of the returned tensor.
Returns:
- A tensor containing the sum of the input tensors, with the device of the
- tensor being `dst_device`.
+ A tensor containing the sum of the input tensors.
+
+ Raises:
+ LookupError: If context is not currently using a GPU device.
+ """
+ return _apply_reduce('sum', tensors)
+
+
+@ops.RegisterGradient('NcclReduce')
+def _reduce_sum_grad(op, grad):
+ """The gradients for input `Operation` of `reduce_sum`.
+
+ Args:
+ op: The `sum send` `Operation` that we are differentiating.
+ grad: Gradient with respect to the output of the `reduce_sum` op.
+
+ Returns:
+ The gradient with respect to the input of `reduce_sum` op.
+
+ Raises:
+ LookupError: If the reduction attribute of op is not `sum`.
"""
- return _apply_reduce('sum', tensors, dst_device)
+ if op.get_attr('reduction') != 'sum':
+ raise LookupError('No gradient defined for NcclReduce except sum.')
+ _check_device(grad, expected=op.device)
+ with ops.device(op.device):
+ result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape)
-def broadcast(src_tensor, dst_devices):
- """Returns a list of tensors on `dst_devices`, each with value `tensor`.
+ return [result] * len(op.inputs)
- The computation is done with a broadcast nccl operation, so if only some of
- the returned tensors and src_tensor are evaluated then the computation will
- hang.
+
+def broadcast(tensor):
+ """Returns a tensor that can be efficiently transferred to other devices.
Args:
- src_tensor: The tensor to send; must be assigned to a GPU device.
- dst_devices: The GPU devices to receive the sent tensor.
+ tensor: The tensor to send; must be assigned to a GPU device.
Returns:
- An `Operation` to send the `src_tensor`, and a list of tensors, each with
- the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`.
+ A tensor with the value of `src_tensor`, which can be used as input to
+ ops on other GPU devices.
"""
- if not dst_devices:
- raise ValueError('Must pass >0 dst_devices to broadcast')
_check_graph_mode()
- _check_device_assignment(src_tensor)
+ _check_device(tensor)
- shape = array_ops.shape(src_tensor, out_type=dtypes.int64)
- num_devices = len(dst_devices) + 1
- shared_name = _get_shared_name()
+ with ops.device(tensor.device):
+ return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape)
- with ops.device(src_tensor.device):
- send = gen_nccl_ops.nccl_broadcast_send(
- input=src_tensor, num_devices=num_devices, shared_name=shared_name)
-
- recvs = []
- for d in dst_devices:
- with ops.device(d):
- recvs.append(
- gen_nccl_ops.nccl_broadcast_recv(
- shape=shape,
- T=src_tensor.dtype,
- num_devices=num_devices,
- shared_name=shared_name))
- return send, recvs
+@ops.RegisterGradient('NcclBroadcast')
+def _broadcast_grad(op, accumulated_grad):
+ """The gradients for input `Operation` of `broadcast`.
+
+ Args:
+ op: The `broadcast send` `Operation` that we are differentiating.
+ accumulated_grad: Accumulated gradients with respect to the output of the
+ `broadcast` op.
+
+ Returns:
+ Gradients with respect to the input of `broadcast`.
+ """
+ # Grab inputs of accumulated_grad and replace accumulation with reduce_sum.
+ grads = [t for t in accumulated_grad.op.inputs]
+ for t in grads:
+ _check_device(t)
+
+ with ops.device(op.device):
+ return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum')
def _apply_all_reduce(reduction, tensors):
@@ -198,7 +218,7 @@ def _apply_all_reduce(reduction, tensors):
res = []
for t in tensors:
- _check_device_assignment(t)
+ _check_device(t)
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
@@ -210,40 +230,20 @@ def _apply_all_reduce(reduction, tensors):
return res
-def _apply_reduce(reduction, tensors, dst_device):
+def _apply_reduce(reduction, tensors):
"""Helper function for reduce_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to reduce operations')
- if not dst_device:
- raise ValueError('Must pass dst_device to reduce operations')
_check_graph_mode()
+ for t in tensors:
+ _check_device(t)
+ result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction)
try:
- recv_index = next(i for i, t in enumerate(tensors)
- if t.device == dst_device)
+ next(t for t in tensors if t.device == result.device)
except StopIteration:
- raise ValueError('One of the tensors must be assigned to dst_device')
- shared_name = _get_shared_name()
-
- sends = []
- for t in tensors[:recv_index] + tensors[recv_index + 1:]:
- _check_device_assignment(t)
- with ops.device(t.device):
- sends.append(
- gen_nccl_ops.nccl_reduce_send(
- input=t,
- reduction=reduction,
- num_devices=len(tensors),
- shared_name=shared_name))
-
- with ops.device(dst_device):
- recv = gen_nccl_ops.nccl_reduce_recv(
- input=tensors[recv_index],
- reduction=reduction,
- num_devices=len(tensors),
- shared_name=shared_name)
-
- return recv, sends
+ raise ValueError('One input tensor must be assigned to current device')
+ return result
_lock = threading.Lock()
@@ -259,9 +259,11 @@ def _get_shared_name():
return 'c%s' % val
-def _check_device_assignment(tensor):
+def _check_device(tensor, expected=None):
if not device.canonical_name(tensor.device):
raise ValueError('Device assignment required for nccl collective ops')
+ if expected and expected != tensor.device:
+ raise ValueError('Expected device %s, got %s' % (expected, tensor.device))
def _check_graph_mode():
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 96d67723a0..255409303a 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -22,8 +22,10 @@ from functools import partial
import numpy as np
from tensorflow.contrib import nccl
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
from tensorflow.python.platform import test
@@ -36,27 +38,30 @@ def _DeviceTensors(tensors, devices):
def _NcclAllReduce(nccl_fun, tensors, devices):
- return nccl_fun(_DeviceTensors(tensors, devices)), []
+ return nccl_fun(_DeviceTensors(tensors, devices))
def _NcclReduce(nccl_fun, tensors, devices):
- d_tensors = _DeviceTensors(tensors, devices)
receiver = np.random.randint(0, len(devices))
- received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
- return [received_tensor], send_ops
+ with ops.device(devices[receiver]):
+ return [nccl_fun(_DeviceTensors(tensors, devices))]
def _NcclBroadcast(tensors, devices):
sender = np.random.randint(0, len(devices))
- d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
- other_devices = devices[:sender] + devices[sender + 1:]
- send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
- return received_tensors, [send_op]
+ with ops.device(devices[sender]):
+ tensor = array_ops.identity(tensors[0])
+ broadcast = nccl.broadcast(tensor)
+ return _DeviceTensors([broadcast] * len(devices), devices)
class NcclTestCase(test.TestCase):
- def _Test(self, nccl_reduce, numpy_fn):
+ def _Test(self,
+ nccl_reduce,
+ numpy_fn,
+ device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
+ ['/device:GPU:1', '/device:GPU:0'])):
"""Tests that nccl_reduce does the same as reduction with numpy_fn.
Args:
@@ -65,6 +70,7 @@ class NcclTestCase(test.TestCase):
reduction.
numpy_fn: A function taking two tensors and returning the reduction of the
two.
+ device_sets: Tuple of virtual devices to run test on.
"""
if not test.is_gpu_available():
return # Test requires access to a GPU
@@ -74,26 +80,28 @@ class NcclTestCase(test.TestCase):
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
- for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
- ['/device:GPU:1', '/device:GPU:0']]:
+ for devices in device_sets:
shape = (3, 4)
random = (np.random.random_sample(shape) - .5) * 1024
- tensors = [random.astype(dtype)] * len(devices)
+ tensors = []
+ for _ in devices:
+ tensors.append(random.astype(dtype))
np_ans = tensors[0]
for t in tensors[1:]:
np_ans = numpy_fn(np_ans, t)
- reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
+ reduce_tensors = nccl_reduce(tensors, devices)
self.assertNotEmpty(reduce_tensors)
# Test shape inference.
for r in reduce_tensors:
self.assertEqual(shape, r.get_shape())
+ result_tensors = [array_ops.identity(t) for t in reduce_tensors]
+
# Test execution and results.
- nccl_results = sess.run(reduce_tensors + reduce_ops)
- for r in nccl_results[:len(reduce_tensors)]:
- self.assertAllClose(r, np_ans)
+ for t in sess.run(result_tensors):
+ self.assertAllClose(t, np_ans)
def _TestGradient(self, nccl_reduce, numpy_fn):
"""Tests the gradient of nccl_reduce.
@@ -106,14 +114,11 @@ class NcclTestCase(test.TestCase):
reduction of the two.
"""
def _Gradient(tensors, devices):
- reduce_tensors, _ = nccl_reduce(tensors, devices)
- tensor_ops = [t.op for t in reduce_tensors]
- d_tensors = _DeviceTensors(tensors, devices)
- grad_tensors = [
- ops.get_gradient_function(op)(op, loss)
- for op, loss in zip(tensor_ops, d_tensors)
- ]
- return grad_tensors, []
+ inputs = [array_ops.placeholder(t.dtype, t.shape) for t in tensors]
+ reduce_tensors = nccl_reduce(inputs, devices)
+ losses = _DeviceTensors(tensors, [t.device for t in reduce_tensors])
+ grads = gradients.gradients(reduce_tensors, inputs, losses)
+ return [g for g in grads if g is not None]
self._Test(_Gradient, numpy_fn)
@@ -142,27 +147,43 @@ class SingleReduceTest(NcclTestCase):
def testSum(self):
self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
+ def testSumGrad(self):
+ self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x)
+
class BroadcastTest(NcclTestCase):
def testBroadcast(self):
self._Test(_NcclBroadcast, lambda x, y: x)
+ def testBroadcastSingleDevice(self):
+ # Broadcasts on a single device are removed completely during rewrite.
+ self._Test(_NcclBroadcast, lambda x, y: x,
+ (['/device:GPU:0', '/device:GPU:0']))
+
+ def testBroadcastToCpuError(self):
+ # Broadcasts to CPU is not supported.
+ with self.assertRaisesRegexp(
+ errors.NotFoundError,
+ "No registered '_NcclBroadcastRecv' OpKernel for CPU devices"):
+ self._Test(_NcclBroadcast, lambda x, y: x,
+ (['/device:GPU:0', '/device:CPU:0']))
+
+ def testBroadcastGrad(self):
+ self._TestGradient(_NcclBroadcast, lambda x, y: x + y)
+
class CombinedTest(NcclTestCase):
"""Test all-reduce vs. single-reduce plus broadcast in one session.run."""
- def _combined(self, tensors, devices):
- all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
- single_reduce_tensors, single_reduce_ops = _NcclReduce(
- nccl.reduce_sum, tensors, devices)
- broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
- devices)
- all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors
- return all_tensors, single_reduce_ops + broadcast_ops
+ def _Combined(self, tensors, devices):
+ all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)
+ single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices)
+ broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices)
+ return all_reduce_tensors + broadcast_tensors
def testCombined(self):
- self._Test(self._combined, lambda x, y: x + y)
+ self._Test(self._Combined, lambda x, y: x + y)
if __name__ == '__main__':