From ca8af1d0dbb605087a4f8ae076188f2b9a26b1ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 15 Oct 2017 21:36:48 -0700 Subject: Replace NcclReduce/Broadcast ops during graph optimization so that we can generate gradients for Reduce/Broadcast. Changing _NcclBroadcastRecv shape input to int32 so that the corresponding Const op is outputting to HostMem. PiperOrigin-RevId: 172279684 --- tensorflow/contrib/nccl/BUILD | 2 + tensorflow/contrib/nccl/kernels/nccl_ops.cc | 28 ++- tensorflow/contrib/nccl/kernels/nccl_rewrite.cc | 271 +++++++++++++++++++++ tensorflow/contrib/nccl/ops/nccl_ops.cc | 84 +++++-- tensorflow/contrib/nccl/python/ops/nccl_ops.py | 138 +++++------ .../contrib/nccl/python/ops/nccl_ops_test.py | 87 ++++--- 6 files changed, 483 insertions(+), 127 deletions(-) create mode 100644 tensorflow/contrib/nccl/kernels/nccl_rewrite.cc (limited to 'tensorflow/contrib/nccl') diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index d6508362b8..5e7263ff62 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -71,10 +71,12 @@ tf_kernel_library( "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", + "kernels/nccl_rewrite.cc", ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:proto_text", "@nccl_archive//:nccl", ], alwayslink = 1, diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 4eb52492db..266d4f6f0d 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -15,8 +15,6 @@ limitations under the License. #if GOOGLE_CUDA -#include -#include #include #include "src/nccl.h" @@ -24,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace { // Base class for all communicator ops that use nccl. 
// @@ -134,7 +133,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase { compute_stream, &c->input(0), std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceSend").Device(DEVICE_GPU), NcclReduceSendKernel); // To execute a single reduce, this kernel is called once for one devices, and @@ -166,7 +165,7 @@ class NcclReduceRecvKernel : public NcclReduceOpBase { private: ncclRedOp_t reduction_op_; }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceRecv").Device(DEVICE_GPU), NcclReduceRecvKernel); // To execute a single broadcast, this kernel is called once for one device, and @@ -191,7 +190,7 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase { std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU), NcclBroadcastSendKernel); // To execute a single broadcast, this kernel is called once for all but one of @@ -206,7 +205,7 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { const Tensor& shape_t = c->input(0); TensorShape shape; OP_REQUIRES_OK_ASYNC( - c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done); + c, TensorShapeUtils::MakeShape(shape_t.vec<int32>(), &shape), done); Tensor* out_t; OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done); @@ -224,9 +223,24 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { } }; REGISTER_KERNEL_BUILDER( - Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), + Name("_NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), NcclBroadcastRecvKernel); +// Define stub kernels for the ops that get replaced post placement. +class NcclStubKernel : public AsyncOpKernel { + public: + explicit NcclStubKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + c->SetStatus(errors::Unimplemented( + "This op should be replaced during graph optimization.")); + done(); + } +}; +REGISTER_KERNEL_BUILDER(Name("NcclBroadcast").Device(DEVICE_GPU), + NcclStubKernel); +REGISTER_KERNEL_BUILDER(Name("NcclReduce").Device(DEVICE_GPU), NcclStubKernel); + +} // namespace } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc new file mode 100644 index 0000000000..94a77c59da --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -0,0 +1,271 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#if GOOGLE_CUDA + +#include <forward_list> +#include <vector> + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace { + +// Replaces NcclReduce node with _NcclReduceRecv reusing one input of same +// device, adds one _NcclReduceSend for each other input. +Status ReplaceReduce(Graph* graph, Node* node) { + string reduction; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction)); + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int num_devices = node->num_inputs(); + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("reduction", reduction) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + std::vector<Node*> control_inputs; + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + control_inputs.push_back(edge->src()); + } + } + std::vector<NodeBuilder::NodeOut> out_nodes; + for (const auto& edge : node->out_edges()) { + out_nodes.emplace_back(edge->dst(), edge->dst_input()); + } + int recv_dev = node->assigned_device_name_index(); + NodeBuilder recv_builder = + make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs); + bool recv_input_set = false; + int send_counter = 0; + for (const auto& edge : node->in_edges()) { + Node* src_node = edge->src(); + if (edge->IsControlEdge()) { + continue; + } + int send_dev = src_node->assigned_device_name_index(); + if (!recv_input_set && send_dev == recv_dev) { + recv_builder.Input(src_node); + recv_input_set = true; + continue; + } + auto send_builder = make_builder("_NcclReduceSend", + strings::StrCat("Send_", ++send_counter)) + .Input(src_node) + .ControlInputs(control_inputs); + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + // Send nodes don't have any outputs and therefore have no data dependencies + // to the outputs of the graph. We add a control dependency to the receive + // node so that those 'dangling' nodes are run. + // TODO(b/67027412): Avoid these cross-device control edges. + for (const auto& out_node : out_nodes) { + graph->AddControlEdge(send_node, out_node.node); + } + } + if (!recv_input_set) { + return errors::InvalidArgument( + "No input tensor uses the same device as the NcclReduce op"); + } + Node* recv_node = nullptr; + TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + graph->RemoveNode(node); + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(recv_node, out_node.node); + } else { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + return Status::OK(); +} + +TensorProto TensorFromShape(const TensorShapeProto& shape) { + TensorProto result; + result.set_dtype(DT_INT32); + for (const auto& dim : shape.dim()) { + result.add_int_val(dim.size()); + } + result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size()); + return result; +} + +// Replaces NcclBroadcast node with _NcclBroadcastSend, connects the input to +// all outputs of same device, adds one _NcclBroadcastRecv for each other output +// device.
+Status ReplaceBroadcast(Graph* graph, Node* node) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int send_dev = node->assigned_device_name_index(); + int num_devices = 0; // Number of distinct devices, incremented below. + + // Map device name index to nodes that take the broadcast as input. + std::vector<std::forward_list<NodeBuilder::NodeOut>> out_nodes_map; + for (const auto& edge : node->out_edges()) { + int dst_dev = edge->IsControlEdge() + ? send_dev + : edge->dst()->assigned_device_name_index(); + if (out_nodes_map.size() <= dst_dev) { + out_nodes_map.resize(dst_dev + 1); + } + auto it = out_nodes_map.begin() + dst_dev; + if (it->empty()) { + ++num_devices; + } + it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input())); + } + + if (num_devices <= 1) { + // Only one participating device, skip NCCL op. + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge)); + Node* in_node = in_edge->src(); + int in_index = in_edge->src_output(); + graph->RemoveNode(node); + for (const auto& out_nodes : out_nodes_map) { + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(in_node, out_node.node); + } else { + graph->AddEdge(in_node, in_index, out_node.node, out_node.index); + } + } + } + return Status::OK(); + } + + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + + // Create broadcast send node and replace the original broadcast node. + NodeBuilder::NodeOut in_node; + NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send"); + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + send_builder.ControlInput(edge->src()); + } else { + in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output()); + send_builder.Input(in_node); + } + } + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + + TensorShapeProto shape_proto; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto)); + + // Delete the original node before reconnecting to outputs. + graph->RemoveNode(node); + + // Connect all outputs on the device of broadcast send. + for (const auto& out_node : out_nodes_map[send_dev]) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(send_node, out_node.node); + } else { + graph->AddEdge(in_node.node, in_node.index, out_node.node, + out_node.index); + // Add control edge so send node is run. + graph->AddControlEdge(send_node, out_node.node); + } + } + out_nodes_map[send_dev].clear(); + + TensorProto tensor_proto = TensorFromShape(shape_proto); + bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined(); + string shape_name = strings::StrCat(in_node.node->name(), "/Shape"); + Node* shape_node = nullptr; + if (!is_fully_defined) { + NodeBuilder shape_builder(shape_name, "Shape"); + shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(send_dev); + } + + // For all other devices, create a broadcast receive and connect outputs.
+ for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) { + if (out_nodes_map[recv_dev].empty()) { + continue; + } + if (is_fully_defined) { + // If the shape is fully defined, define one const node per device. + NodeBuilder shape_builder(strings::StrCat(shape_name, recv_dev), "Const"); + shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(recv_dev); + } + Node* recv_node; + TF_RETURN_IF_ERROR( + make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_dev)) + .Input(shape_node) + .Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + for (const auto& out_node : out_nodes_map[recv_dev]) { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + + return Status::OK(); +} + +// Replaces occurrences of Nccl{Reduce, Broadcast}Input/Output with their +// _Nccl...Send/Recv counterparts and removes data dependencies between them. +class NcclReplacePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override { + if (options.graph == nullptr) { + return Status::OK(); + } + Graph* graph = options.graph->get(); + if (graph == nullptr) { + return errors::Internal( + "NCCL replacement should happen before partitioning and a " + "graph should be available."); + } + // Find reduction and broadcast ops and replace them with Send/Recv ops. + for (Node* node : graph->op_nodes()) { + StringPiece type = node->type_string(); + if (!type.starts_with("Nccl")) { + continue; + } + if (type == "NcclReduce") { + TF_RETURN_IF_ERROR(ReplaceReduce(graph, node)); + } + if (type == "NcclBroadcast") { + TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node)); + } + } + return Status::OK(); + } +}; +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, + NcclReplacePass); + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc index 532c79c24c..8eb804c2e9 100644 --- a/tensorflow/contrib/nccl/ops/nccl_ops.cc +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -45,7 +45,28 @@ num_devices: The number of devices participating in this reduction. shared_name: Identifier that shared between ops of the same reduction. )doc"); -REGISTER_OP("NcclReduceSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclReduceSend and _NcclReduceRecv during graph optimization stage. +REGISTER_OP("NcclReduce") + .Input("input: num_devices * T") + .Output("data: T") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Reduces `input` from `num_devices` using `reduction` to a single device. + +The graph should be constructed so that all inputs have a valid device +assignment, and the op itself is assigned one of these devices. + +input: The input to the reduction. +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. 
+ )doc"); + +REGISTER_OP("_NcclReduceSend") .Input("input: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") .Attr("T: {float, float64, int32, int64}") @@ -54,19 +75,20 @@ REGISTER_OP("NcclReduceSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. +Replacement node for NcclReduce. +Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclReduceRecv") +REGISTER_OP("_NcclReduceRecv") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") @@ -76,21 +98,42 @@ REGISTER_OP("NcclReduceRecv") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( +Replacement node for NcclReduce. + Reduces 'input' from this op and the NcclReduceSend ops registered in the same `shared_name`. - The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. data: The reduced data received from this op and the NcclReduceSend op. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclBroadcastSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage. +REGISTER_OP("NcclBroadcast") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Sends `input` to all devices that are connected to the output. + +The graph should be constructed so that all ops connected to the output have a +valid device assignment, and the op itself is assigned one of these devices. + +input: The input to the broadcast. +output: The same as input. +shape: The shape of the input tensor. + )doc"); + +REGISTER_OP("_NcclBroadcastSend") .Input("input: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -98,19 +141,21 @@ REGISTER_OP("NcclBroadcastSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends `input` to the _NcclBroadcastRecv ops registered in the same +`shared_name`. 
+The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the broadcast +input: The input to the broadcast. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same broadcast. )doc"); -REGISTER_OP("NcclBroadcastRecv") - .Input("shape: int64") +REGISTER_OP("_NcclBroadcastRecv") + .Input("shape: int32") .Output("output: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -123,11 +168,12 @@ REGISTER_OP("NcclBroadcastRecv") return Status::OK(); }) .Doc(R"doc( -Sends data of shape `shape` from the NcclBroadcastSend op registered in the -same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends data of shape `shape` from the _NcclBroadcastSend op registered in the +same `shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. shape: The shape of the output. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 906d9f948a..8dc038b9ac 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -23,9 +23,7 @@ from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.util import loader from tensorflow.python.eager import context from tensorflow.python.framework import device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader _nccl_ops_so = loader.load_op_library( @@ -64,13 +62,13 @@ def _all_sum_grad(op, grad): LookupError: If `reduction` is not `sum`. """ if op.get_attr('reduction') != 'sum': - raise LookupError('No gradient defined for NcclAllReduce except all_sum.') + raise LookupError('No gradient defined for NcclAllReduce except sum.') - _check_device_assignment(grad) + _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') shared_name = op.get_attr('shared_name') + '_grad' - with ops.device(grad.device): + with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( input=grad, reduction='sum', @@ -129,7 +127,7 @@ def all_max(tensors): return _apply_all_reduce('max', tensors) -def reduce_sum(tensors, dst_device): +def reduce_sum(tensors): """Returns a tensor with the reduce sum across `tensors`. The computation is done with a reduce operation, so only one tensor is @@ -138,54 +136,76 @@ def reduce_sum(tensors, dst_device): Args: tensors: The input tensors across which to sum; must be assigned to GPU devices. - dst_device: The device of the returned tensor. Returns: - A tensor containing the sum of the input tensors, with the device of the - tensor being `dst_device`. + A tensor containing the sum of the input tensors. + + Raises: + LookupError: If context is not currently using a GPU device. 
+ """ + return _apply_reduce('sum', tensors) + + +@ops.RegisterGradient('NcclReduce') +def _reduce_sum_grad(op, grad): + """The gradients for input `Operation` of `reduce_sum`. + + Args: + op: The `sum send` `Operation` that we are differentiating. + grad: Gradient with respect to the output of the `reduce_sum` op. + + Returns: + The gradient with respect to the input of `reduce_sum` op. + + Raises: + LookupError: If the reduction attribute of op is not `sum`. """ - return _apply_reduce('sum', tensors, dst_device) + if op.get_attr('reduction') != 'sum': + raise LookupError('No gradient defined for NcclReduce except sum.') + _check_device(grad, expected=op.device) + with ops.device(op.device): + result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape) -def broadcast(src_tensor, dst_devices): - """Returns a list of tensors on `dst_devices`, each with value `tensor`. + return [result] * len(op.inputs) - The computation is done with a broadcast nccl operation, so if only some of - the returned tensors and src_tensor are evaluated then the computation will - hang. + +def broadcast(tensor): + """Returns a tensor that can be efficiently transferred to other devices. Args: - src_tensor: The tensor to send; must be assigned to a GPU device. - dst_devices: The GPU devices to receive the sent tensor. + tensor: The tensor to send; must be assigned to a GPU device. Returns: - An `Operation` to send the `src_tensor`, and a list of tensors, each with - the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`. + A tensor with the value of `src_tensor`, which can be used as input to + ops on other GPU devices. """ - if not dst_devices: - raise ValueError('Must pass >0 dst_devices to broadcast') _check_graph_mode() - _check_device_assignment(src_tensor) + _check_device(tensor) - shape = array_ops.shape(src_tensor, out_type=dtypes.int64) - num_devices = len(dst_devices) + 1 - shared_name = _get_shared_name() + with ops.device(tensor.device): + return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape) - with ops.device(src_tensor.device): - send = gen_nccl_ops.nccl_broadcast_send( - input=src_tensor, num_devices=num_devices, shared_name=shared_name) - - recvs = [] - for d in dst_devices: - with ops.device(d): - recvs.append( - gen_nccl_ops.nccl_broadcast_recv( - shape=shape, - T=src_tensor.dtype, - num_devices=num_devices, - shared_name=shared_name)) - return send, recvs +@ops.RegisterGradient('NcclBroadcast') +def _broadcast_grad(op, accumulated_grad): + """The gradients for input `Operation` of `broadcast`. + + Args: + op: The `broadcast send` `Operation` that we are differentiating. + accumulated_grad: Accumulated gradients with respect to the output of the + `broadcast` op. + + Returns: + Gradients with respect to the input of `broadcast`. + """ + # Grab inputs of accumulated_grad and replace accumulation with reduce_sum. 
+ grads = [t for t in accumulated_grad.op.inputs] + for t in grads: + _check_device(t) + + with ops.device(op.device): + return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum') def _apply_all_reduce(reduction, tensors): @@ -198,7 +218,7 @@ def _apply_all_reduce(reduction, tensors): res = [] for t in tensors: - _check_device_assignment(t) + _check_device(t) with ops.device(t.device): res.append( gen_nccl_ops.nccl_all_reduce( @@ -210,40 +230,20 @@ def _apply_all_reduce(reduction, tensors): return res -def _apply_reduce(reduction, tensors, dst_device): +def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - if not dst_device: - raise ValueError('Must pass dst_device to reduce operations') _check_graph_mode() + for t in tensors: + _check_device(t) + result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction) try: - recv_index = next(i for i, t in enumerate(tensors) - if t.device == dst_device) + next(t for t in tensors if t.device == result.device) except StopIteration: - raise ValueError('One of the tensors must be assigned to dst_device') - shared_name = _get_shared_name() - - sends = [] - for t in tensors[:recv_index] + tensors[recv_index + 1:]: - _check_device_assignment(t) - with ops.device(t.device): - sends.append( - gen_nccl_ops.nccl_reduce_send( - input=t, - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name)) - - with ops.device(dst_device): - recv = gen_nccl_ops.nccl_reduce_recv( - input=tensors[recv_index], - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name) - - return recv, sends + raise ValueError('One input tensor must be assigned to current device') + return result _lock = threading.Lock() @@ -259,9 +259,11 @@ def _get_shared_name(): return 'c%s' % val -def _check_device_assignment(tensor): +def _check_device(tensor, expected=None): if not device.canonical_name(tensor.device): raise ValueError('Device assignment required for nccl collective ops') + if expected and expected != tensor.device: + raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) def _check_graph_mode(): diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 96d67723a0..255409303a 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -22,8 +22,10 @@ from functools import partial import numpy as np from tensorflow.contrib import nccl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients from tensorflow.python.platform import test @@ -36,27 +38,30 @@ def _DeviceTensors(tensors, devices): def _NcclAllReduce(nccl_fun, tensors, devices): - return nccl_fun(_DeviceTensors(tensors, devices)), [] + return nccl_fun(_DeviceTensors(tensors, devices)) def _NcclReduce(nccl_fun, tensors, devices): - d_tensors = _DeviceTensors(tensors, devices) receiver = np.random.randint(0, len(devices)) - received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver]) - return [received_tensor], send_ops + with ops.device(devices[receiver]): + return [nccl_fun(_DeviceTensors(tensors, devices))] def _NcclBroadcast(tensors, devices): sender = np.random.randint(0, len(devices)) - d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0] - other_devices = devices[:sender] 
+ devices[sender + 1:] - send_op, received_tensors = nccl.broadcast(d_tensor, other_devices) - return received_tensors, [send_op] + with ops.device(devices[sender]): + tensor = array_ops.identity(tensors[0]) + broadcast = nccl.broadcast(tensor) + return _DeviceTensors([broadcast] * len(devices), devices) class NcclTestCase(test.TestCase): - def _Test(self, nccl_reduce, numpy_fn): + def _Test(self, + nccl_reduce, + numpy_fn, + device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:1', '/device:GPU:0'])): """Tests that nccl_reduce does the same as reduction with numpy_fn. Args: @@ -65,6 +70,7 @@ class NcclTestCase(test.TestCase): reduction. numpy_fn: A function taking two tensors and returning the reduction of the two. + device_sets: Tuple of virtual devices to run test on. """ if not test.is_gpu_available(): return # Test requires access to a GPU @@ -74,26 +80,28 @@ class NcclTestCase(test.TestCase): # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], - ['/device:GPU:1', '/device:GPU:0']]: + for devices in device_sets: shape = (3, 4) random = (np.random.random_sample(shape) - .5) * 1024 - tensors = [random.astype(dtype)] * len(devices) + tensors = [] + for _ in devices: + tensors.append(random.astype(dtype)) np_ans = tensors[0] for t in tensors[1:]: np_ans = numpy_fn(np_ans, t) - reduce_tensors, reduce_ops = nccl_reduce(tensors, devices) + reduce_tensors = nccl_reduce(tensors, devices) self.assertNotEmpty(reduce_tensors) # Test shape inference. for r in reduce_tensors: self.assertEqual(shape, r.get_shape()) + result_tensors = [array_ops.identity(t) for t in reduce_tensors] + # Test execution and results. - nccl_results = sess.run(reduce_tensors + reduce_ops) - for r in nccl_results[:len(reduce_tensors)]: - self.assertAllClose(r, np_ans) + for t in sess.run(result_tensors): + self.assertAllClose(t, np_ans) def _TestGradient(self, nccl_reduce, numpy_fn): """Tests the gradient of nccl_reduce. @@ -106,14 +114,11 @@ class NcclTestCase(test.TestCase): reduction of the two. """ def _Gradient(tensors, devices): - reduce_tensors, _ = nccl_reduce(tensors, devices) - tensor_ops = [t.op for t in reduce_tensors] - d_tensors = _DeviceTensors(tensors, devices) - grad_tensors = [ - ops.get_gradient_function(op)(op, loss) - for op, loss in zip(tensor_ops, d_tensors) - ] - return grad_tensors, [] + inputs = [array_ops.placeholder(t.dtype, t.shape) for t in tensors] + reduce_tensors = nccl_reduce(inputs, devices) + losses = _DeviceTensors(tensors, [t.device for t in reduce_tensors]) + grads = gradients.gradients(reduce_tensors, inputs, losses) + return [g for g in grads if g is not None] self._Test(_Gradient, numpy_fn) @@ -142,27 +147,43 @@ class SingleReduceTest(NcclTestCase): def testSum(self): self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y) + def testSumGrad(self): + self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x) + class BroadcastTest(NcclTestCase): def testBroadcast(self): self._Test(_NcclBroadcast, lambda x, y: x) + def testBroadcastSingleDevice(self): + # Broadcasts on a single device are removed completely during rewrite. + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:GPU:0'])) + + def testBroadcastToCpuError(self): + # Broadcasts to CPU is not supported. 
+ with self.assertRaisesRegexp( + errors.NotFoundError, + "No registered '_NcclBroadcastRecv' OpKernel for CPU devices"): + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:CPU:0'])) + + def testBroadcastGrad(self): + self._TestGradient(_NcclBroadcast, lambda x, y: x + y) + class CombinedTest(NcclTestCase): """Test all-reduce vs. single-reduce plus broadcast in one session.run.""" - def _combined(self, tensors, devices): - all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0] - single_reduce_tensors, single_reduce_ops = _NcclReduce( - nccl.reduce_sum, tensors, devices) - broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors, - devices) - all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors - return all_tensors, single_reduce_ops + broadcast_ops + def _Combined(self, tensors, devices): + all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices) + single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices) + broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices) + return all_reduce_tensors + broadcast_tensors def testCombined(self): - self._Test(self._combined, lambda x, y: x + y) + self._Test(self._Combined, lambda x, y: x + y) if __name__ == '__main__': -- cgit v1.2.3
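
Usage sketch (illustrative, not part of the patch itself): a minimal graph-mode Python example of the API after this change. It assumes a TensorFlow 1.x build with contrib.nccl and at least two visible GPUs; the device strings and variable names are placeholders, and the snippet is an untested sketch rather than a definitive program. Note that reduce_sum and broadcast no longer take dst_device/dst_devices arguments, because the NcclReduce/NcclBroadcast stub ops are turned into _Nccl*Send/_Nccl*Recv pairs by the POST_PLACEMENT rewrite pass.

# Illustrative sketch, not part of the patch above. Assumes >= 2 visible GPUs
# and the contrib.nccl Python API as modified by this change.
import tensorflow as tf
from tensorflow.contrib import nccl

devices = ['/device:GPU:0', '/device:GPU:1']  # placeholder device names
inputs = []
for d in devices:
  with tf.device(d):
    inputs.append(tf.constant([[1.0, 2.0], [3.0, 4.0]]))

with tf.device(devices[0]):
  # NcclReduce stub op; rewritten to _NcclReduceSend/_NcclReduceRecv.
  reduced = nccl.reduce_sum(inputs)
  # NcclBroadcast stub op; rewritten to _NcclBroadcastSend/_NcclBroadcastRecv.
  broadcast = nccl.broadcast(reduced)

# Consumers on each device read the broadcast value locally.
outputs = []
for d in devices:
  with tf.device(d):
    outputs.append(tf.identity(broadcast))

# Gradients now flow through NcclReduce thanks to the registered gradient
# function; the incoming gradient is placed explicitly because the gradient
# functions require a device assignment matching the forward op.
with tf.device(devices[0]):
  grad_ys = tf.ones_like(reduced)
grads = tf.gradients([reduced], inputs, grad_ys=[grad_ys])

with tf.Session() as sess:
  print(sess.run(outputs + grads))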