| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-12 02:00:55 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-12 02:04:54 -0700 |
| commit | 9b9e54b344803994cfd7997035c62ecbbf24d152 (patch) | |
| tree | cc09fbd68d62ade219f993138ce40ce51155ccad /tensorflow/contrib/nccl | |
| parent | bc300318e7bb9c2b5f1dcbfdc6f2d97d5279abf8 (diff) | |
Adding NCCL sum op, register all_sum gradient.
Streamlining nccl test.
PiperOrigin-RevId: 168347428
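For context, here is a minimal sketch (not part of the commit) of how the new `reduce_sum` API introduced by this change might be used. The device names and shapes are illustrative, and it assumes TF 1.x graph mode on a machine with at least two GPUs:

```python
# Illustrative sketch only; assumes >= 2 GPUs and graph mode.
import tensorflow as tf
from tensorflow.contrib import nccl

tensors = []
for d in ['/device:GPU:0', '/device:GPU:1']:
  with tf.device(d):
    # nccl ops require an explicit device assignment on every input.
    tensors.append(tf.constant([[1.0, 2.0], [3.0, 4.0]]))

# reduce_sum places the reduced tensor on dst_device; one of the inputs
# must already live on that device. It also returns the send ops from the
# other devices, which must be fetched in the same session.run for the
# collective to complete.
summed, send_ops = nccl.reduce_sum(tensors, dst_device='/device:GPU:0')

with tf.Session() as sess:
  result = sess.run([summed] + send_ops)[0]
```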
Diffstat (limited to 'tensorflow/contrib/nccl')

| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/nccl/__init__.py | 2 |
| -rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.cc | 41 |
| -rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.h | 18 |
| -rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_ops.cc | 96 |
| -rw-r--r-- | tensorflow/contrib/nccl/ops/nccl_ops.cc | 45 |
| -rw-r--r-- | tensorflow/contrib/nccl/python/ops/nccl_ops.py | 125 |
| -rw-r--r-- | tensorflow/contrib/nccl/python/ops/nccl_ops_test.py | 211 |
7 files changed, 417 insertions, 121 deletions
diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py
index d851c522c0..4a682cb703 100644
--- a/tensorflow/contrib/nccl/__init__.py
+++ b/tensorflow/contrib/nccl/__init__.py
@@ -18,6 +18,7 @@
 @@all_min
 @@all_prod
 @@all_sum
+@@reduce_sum
 @@broadcast
 """
 
@@ -31,6 +32,7 @@ from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
 from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
 from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
 from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
+from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
 from tensorflow.python.util.all_util import remove_undocumented
 
 remove_undocumented(__name__)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 42e7789301..4b642f64c1 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -260,7 +260,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
   std::vector<ncclComm_t> nccl_comms(num_devices);
   auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
-  CHECK_EQ(result, ncclSuccess);
+  CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
   for (int rank = 0; rank < num_devices; ++rank) {
     members[rank].nccl_comm = nccl_comms[rank];
   }
@@ -307,6 +307,35 @@ void NcclManager::AddBroadcastRecv(
                  kBroadcast, ncclSum /* unused */);
 }
 
+void NcclManager::AddReduceSend(int num_devices, const string& key,
+                                ncclRedOp_t reduction_op,
+                                perftools::gputools::StreamExecutor* executor,
+                                int gpu_device_id, EventMgr* event_mgr,
+                                perftools::gputools::Stream* tensor_stream,
+                                const Tensor* in_t, Tensor* temp_t,
+                                DoneCallback done_callback) {
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, std::move(done_callback)));
+  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+                 kReduce, reduction_op);
+}
+
+void NcclManager::AddReduceRecv(int num_devices, const string& key,
+                                ncclRedOp_t reduction_op,
+                                perftools::gputools::StreamExecutor* executor,
+                                int gpu_device_id, EventMgr* event_mgr,
+                                perftools::gputools::Stream* tensor_stream,
+                                const Tensor* in_t, Tensor* out_t,
+                                DoneCallback done_callback) {
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, std::move(done_callback)));
+  participant->root = true;
+  AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+                 kReduce, reduction_op);
+}
+
 void NcclManager::AddParticipant(int num_devices, const string& key,
                                  std::unique_ptr<Participant> participant,
                                  DataType data_type,
@@ -431,6 +460,14 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
                                  collective->root_rank, nccl_comm, *cu_stream);
         break;
       }
+      case kReduce: {
+        const void* sendbuff = p->in_t->tensor_data().data();
+        void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
+        nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
+                                 data_type, collective->reduction_op,
+                                 collective->root_rank, nccl_comm, *cu_stream);
+        break;
+      }
     }
 
     // Run the done_callback when the nccl kernel finishes running.
@@ -441,7 +478,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
       // Propagate the error, but note that if other members of the collective
       // did launch their kernels, then they are hanging.
       collective->participants[rank]->done_callback(errors::Unknown(
-          "Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
+          "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
     }
 
     // TODO(cwhipkey): use RefCounted after figuring out how to use in a
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 6e2f8e953a..619a1b69bf 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -75,10 +75,28 @@ class NcclManager {
                         perftools::gputools::Stream* tensor_stream,
                         Tensor* out_t, DoneCallback done_callback);
 
+  // AddReduceSend and AddReduceRecv combine to send data from all senders
+  // to one receiver.
+  void AddReduceSend(int num_devices, const string& key,
+                     ncclRedOp_t reduction_op,
+                     perftools::gputools::StreamExecutor* executor,
+                     int gpu_device_id, EventMgr* event_mgr,
+                     perftools::gputools::Stream* tensor_stream,
+                     const Tensor* in_t, Tensor* temp_t,
+                     DoneCallback done_callback);
+  void AddReduceRecv(int num_devices, const string& key,
+                     ncclRedOp_t reduction_op,
+                     perftools::gputools::StreamExecutor* executor,
+                     int gpu_device_id, EventMgr* event_mgr,
+                     perftools::gputools::Stream* tensor_stream,
+                     const Tensor* in_t, Tensor* out_t,
+                     DoneCallback done_callback);
+
  private:
   enum CollectiveType {
     kAllReduce = 1,
     kBroadcast = 2,
+    kReduce = 3,
   };
   struct Collective;
   struct Communicator;
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
index d4455483f7..81cc74416b 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 
+#include <memory>
 #include <unordered_map>
 #include <vector>
 
@@ -58,11 +59,9 @@ class NcclAsyncOpBase : public AsyncOpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
 };
 
-// To execute a single all-reduce, this kernel is called once for each of the
-// <k> devices in the communicator.
-class NcclAllReduceOpKernel : public NcclAsyncOpBase {
+class NcclReduceOpBase : public NcclAsyncOpBase {
  public:
-  explicit NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
+  explicit NcclReduceOpBase(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
     string reduction;
     OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
     if (reduction == "min") {
@@ -79,6 +78,19 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     }
   }
 
+  ncclRedOp_t reduction_op() const { return reduction_op_; }
+
+ private:
+  ncclRedOp_t reduction_op_;
+};
+
+// To execute a single all-reduce, this kernel is called once for each of the
+// <k> devices in the communicator.
+class NcclAllReduceOpKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclAllReduceOpKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
     const Tensor* in_t = &c->input(0);
     Tensor* out_t;
@@ -92,18 +104,81 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddToAllReduce(
-        num_devices(), GetCollectiveKey(c), reduction_op_,
+        num_devices(), GetCollectiveKey(c), reduction_op(),
         compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
-        compute_stream, in_t, out_t, actual_done);
+        compute_stream, in_t, out_t, std::move(actual_done));
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
+                        NcclAllReduceOpKernel);
+
+// To execute a single reduce, this kernel is called once for all but one of
+// the <k> devices in the communicator, and NcclReduceRecvKernel is called
+// once for the remaining device.
+class NcclReduceSendKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclReduceSendKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    const Tensor& in_t = c->input(0);
+    std::unique_ptr<Tensor> temp_ptr(new Tensor());
+    OP_REQUIRES_OK_ASYNC(
+        c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
+    Tensor* temp_t = temp_ptr.release();
+
+    auto actual_done = [c, done, temp_t](Status s) {
+      delete temp_t;
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+
+    auto* compute_stream = c->op_device_context()->stream();
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+    NcclManager::instance()->AddReduceSend(
+        num_devices(), GetCollectiveKey(c), reduction_op(),
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, &in_t, temp_t, std::move(actual_done));
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
+                        NcclReduceSendKernel);
+
+// To execute a single reduce, this kernel is called once for one device, and
+// NcclReduceSendKernel is called for all other <k-1> devices in the
+// communicator.
+class NcclReduceRecvKernel : public NcclReduceOpBase {
+ public:
+  explicit NcclReduceRecvKernel(OpKernelConstruction* c)
+      : NcclReduceOpBase(c) {}
+
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    const Tensor& in_t = c->input(0);
+    Tensor* out_t;
+    OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t.shape(), &out_t), done);
+
+    auto actual_done = [c, done](Status s) {
+      OP_REQUIRES_OK_ASYNC(c, s, done);
+      done();
+    };
+
+    auto* compute_stream = c->op_device_context()->stream();
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+    NcclManager::instance()->AddReduceRecv(
+        num_devices(), GetCollectiveKey(c), reduction_op(),
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, &in_t, out_t, std::move(actual_done));
   }
 
  private:
   ncclRedOp_t reduction_op_;
 };
+REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
+                        NcclReduceRecvKernel);
 
-REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
-                        NcclAllReduceOpKernel);
-
+// To execute a single broadcast, this kernel is called once for one device,
+// and NcclBroadcastRecvKernel is called for all other <k-1> devices in the
+// communicator.
 class NcclBroadcastSendKernel : public NcclAsyncOpBase {
  public:
   explicit NcclBroadcastSendKernel(OpKernelConstruction* c)
@@ -126,6 +201,9 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
 REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
                         NcclBroadcastSendKernel);
 
+// To execute a single broadcast, this kernel is called once for all but one
+// of the <k> devices in the communicator, and NcclBroadcastSendKernel is
+// called once for the remaining device.
 class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
  public:
   explicit NcclBroadcastRecvKernel(OpKernelConstruction* c)
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
index d767636fad..532c79c24c 100644
--- a/tensorflow/contrib/nccl/ops/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -45,6 +45,51 @@ num_devices: The number of devices participating in this reduction.
 shared_name: Identifier that shared between ops of the same reduction.
 )doc");
 
+REGISTER_OP("NcclReduceSend")
+    .Input("input: T")
+    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+    .Attr("T: {float, float64, int32, int64}")
+    .Attr("num_devices: int")
+    .Attr("shared_name: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction.
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+  )doc");
+
+REGISTER_OP("NcclReduceRecv")
+    .Input("input: T")
+    .Output("data: T")
+    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+    .Attr("T: {float, float64, int32, int64}")
+    .Attr("num_devices: int")
+    .Attr("shared_name: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+Reduces 'input' from this op and the NcclReduceSend ops registered in the same
+`shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction.
+data: The reduced data received from this op and the NcclReduceSend op.
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+  )doc");
+
 REGISTER_OP("NcclBroadcastSend")
     .Input("input: T")
     .Attr("T: {float, float64, int32, int64}")
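The op registrations above encode a wiring contract: with `k` participating devices, exactly `num_devices - 1` of them must run `NcclReduceSend` and exactly one must run `NcclReduceRecv`, all under the same `shared_name`, or graph execution hangs. The following hand-rolled sketch mirrors the `_apply_reduce` helper that appears later in this diff; `manual_reduce_sum` is a hypothetical name for illustration only, and the public entry point remains `nccl.reduce_sum`:

```python
# Hypothetical helper showing the NcclReduceSend/NcclReduceRecv pairing.
# All ops must share one shared_name so NcclManager groups them into a
# single collective; a mismatched send/recv count hangs graph execution.
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.python.framework import ops


def manual_reduce_sum(tensors, devices, recv_index, shared_name):
  sends = []
  for i, (t, d) in enumerate(zip(tensors, devices)):
    if i == recv_index:
      continue  # The receiver runs NcclReduceRecv instead of a send.
    with ops.device(d):
      sends.append(
          gen_nccl_ops.nccl_reduce_send(
              input=t, reduction='sum', num_devices=len(devices),
              shared_name=shared_name))
  with ops.device(devices[recv_index]):
    recv = gen_nccl_ops.nccl_reduce_recv(
        input=tensors[recv_index], reduction='sum',
        num_devices=len(devices), shared_name=shared_name)
  return recv, sends
```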
+ )doc"); + REGISTER_OP("NcclBroadcastSend") .Input("input: T") .Attr("T: {float, float64, int32, int64}") diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index b31cc53e0a..906d9f948a 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -21,6 +21,7 @@ import threading from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -48,6 +49,35 @@ def all_sum(tensors): return _apply_all_reduce('sum', tensors) +@ops.RegisterGradient('NcclAllReduce') +def _all_sum_grad(op, grad): + """The gradients for `all_sum`. + + Args: + op: The `all_sum` `Operation` that we are differentiating. + grad: Gradient with respect to the output of the `all_sum` op. + + Returns: + The gradient with respect to the output of `all_sum`. + + Raises: + LookupError: If `reduction` is not `sum`. + """ + if op.get_attr('reduction') != 'sum': + raise LookupError('No gradient defined for NcclAllReduce except all_sum.') + + _check_device_assignment(grad) + num_devices = op.get_attr('num_devices') + shared_name = op.get_attr('shared_name') + '_grad' + + with ops.device(grad.device): + return gen_nccl_ops.nccl_all_reduce( + input=grad, + reduction='sum', + num_devices=num_devices, + shared_name=shared_name) + + def all_prod(tensors): """Returns a list of tensors with the all-reduce product across `tensors`. @@ -99,6 +129,24 @@ def all_max(tensors): return _apply_all_reduce('max', tensors) +def reduce_sum(tensors, dst_device): + """Returns a tensor with the reduce sum across `tensors`. + + The computation is done with a reduce operation, so only one tensor is + returned. + + Args: + tensors: The input tensors across which to sum; must be assigned + to GPU devices. + dst_device: The device of the returned tensor. + + Returns: + A tensor containing the sum of the input tensors, with the device of the + tensor being `dst_device`. + """ + return _apply_reduce('sum', tensors, dst_device) + + def broadcast(src_tensor, dst_devices): """Returns a list of tensors on `dst_devices`, each with value `tensor`. @@ -111,50 +159,93 @@ def broadcast(src_tensor, dst_devices): dst_devices: The GPU devices to receive the sent tensor. Returns: - List of tensors, each with the value of `src_tensor`, which the device - of tensor i is `dst_devices[i]`. + An `Operation` to send the `src_tensor`, and a list of tensors, each with + the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`. 
""" if not dst_devices: raise ValueError('Must pass >0 dst_devices to broadcast') - all_devices = [src_tensor.device] + dst_devices + _check_graph_mode() + _check_device_assignment(src_tensor) + + shape = array_ops.shape(src_tensor, out_type=dtypes.int64) + num_devices = len(dst_devices) + 1 shared_name = _get_shared_name() with ops.device(src_tensor.device): send = gen_nccl_ops.nccl_broadcast_send( - input=src_tensor, num_devices=len(all_devices), shared_name=shared_name) + input=src_tensor, num_devices=num_devices, shared_name=shared_name) - shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64) recvs = [] for d in dst_devices: with ops.device(d): recvs.append( gen_nccl_ops.nccl_broadcast_recv( - shape=shape_op, + shape=shape, T=src_tensor.dtype, - num_devices=len(all_devices), + num_devices=num_devices, shared_name=shared_name)) return send, recvs -def _apply_all_reduce(reduction_op, tensors): +def _apply_all_reduce(reduction, tensors): + """Helper function for all_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to all reduce operations') + _check_graph_mode() + shared_name = _get_shared_name() res = [] + for t in tensors: - if not device.canonical_name(t.device): - raise ValueError('Device assignment required for nccl collective ops') + _check_device_assignment(t) with ops.device(t.device): res.append( gen_nccl_ops.nccl_all_reduce( - t, - reduction=reduction_op, + input=t, + reduction=reduction, num_devices=len(tensors), shared_name=shared_name)) + return res +def _apply_reduce(reduction, tensors, dst_device): + """Helper function for reduce_* functions.""" + if not tensors: + raise ValueError('Must pass >0 tensors to reduce operations') + if not dst_device: + raise ValueError('Must pass dst_device to reduce operations') + _check_graph_mode() + + try: + recv_index = next(i for i, t in enumerate(tensors) + if t.device == dst_device) + except StopIteration: + raise ValueError('One of the tensors must be assigned to dst_device') + shared_name = _get_shared_name() + + sends = [] + for t in tensors[:recv_index] + tensors[recv_index + 1:]: + _check_device_assignment(t) + with ops.device(t.device): + sends.append( + gen_nccl_ops.nccl_reduce_send( + input=t, + reduction=reduction, + num_devices=len(tensors), + shared_name=shared_name)) + + with ops.device(dst_device): + recv = gen_nccl_ops.nccl_reduce_recv( + input=tensors[recv_index], + reduction=reduction, + num_devices=len(tensors), + shared_name=shared_name) + + return recv, sends + + _lock = threading.Lock() _shared_name_counter = 0 @@ -166,3 +257,13 @@ def _get_shared_name(): val = _shared_name_counter _shared_name_counter += 1 return 'c%s' % val + + +def _check_device_assignment(tensor): + if not device.canonical_name(tensor.device): + raise ValueError('Device assignment required for nccl collective ops') + + +def _check_graph_mode(): + if context.in_eager_mode(): + raise ValueError('Nccl ops are not supported in eager mode') diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 1621e9f28e..96d67723a0 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from functools import partial import numpy as np from tensorflow.contrib import nccl @@ -26,58 +27,45 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test 
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 1621e9f28e..96d67723a0 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from functools import partial
 import numpy as np
 
 from tensorflow.contrib import nccl
@@ -26,58 +27,45 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class AllReduceTest(test.TestCase):
+def _DeviceTensors(tensors, devices):
+  res = []
+  for t, d in zip(tensors, devices):
+    with ops.device(d):
+      res.append(array_ops.identity(t))
+  return res
 
-  def testAllReduce(self):
-    if not test.is_gpu_available():
-      return  # Test requires access to a GPU
 
-    for dtype in [np.float32, np.int32, np.int64, np.float64]:
-      # Create session inside outer loop to test use of
-      # same communicator across multiple sessions.
-      with self.test_session(use_gpu=True) as sess:
-        self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
-        self._testSingleAllReduce(sess, dtype, nccl.all_prod,
-                                  lambda x, y: x * y)
-        self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
-        self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
-
-  def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
-    for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-                    ['/device:GPU:1', '/device:GPU:0']]:
-      shape = (3, 4)
-      np_ans = None
-      tensors = []
-      for d in devices:
-        with ops.device(d):
-          t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
-          if np_ans is None:
-            np_ans = t
-          else:
-            np_ans = numpy_accumulation_fn(np_ans, t)
-          tensors.append(array_ops.identity(t))
-
-      all_reduce_tensors = nccl_fn(tensors)
-
-      # Test shape inference.
-      for r in all_reduce_tensors:
-        self.assertEqual(shape, r.get_shape())
-
-      # Test execution and results.
-      nccl_results = sess.run(all_reduce_tensors)
-      for r in nccl_results:
-        self.assertAllClose(r, np_ans)
+def _NcclAllReduce(nccl_fun, tensors, devices):
+  return nccl_fun(_DeviceTensors(tensors, devices)), []
 
-  def testErrors(self):
-    with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
-      nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
-    with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
-      nccl.all_sum([])
 
+def _NcclReduce(nccl_fun, tensors, devices):
+  d_tensors = _DeviceTensors(tensors, devices)
+  receiver = np.random.randint(0, len(devices))
+  received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
+  return [received_tensor], send_ops
 
-class BroadcastTest(test.TestCase):
 
-  def testBroadcast(self):
+def _NcclBroadcast(tensors, devices):
+  sender = np.random.randint(0, len(devices))
+  d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
+  other_devices = devices[:sender] + devices[sender + 1:]
+  send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
+  return received_tensors, [send_op]
+
+
+class NcclTestCase(test.TestCase):
+
+  def _Test(self, nccl_reduce, numpy_fn):
+    """Tests that nccl_reduce does the same as reduction with numpy_fn.
+
+    Args:
+      nccl_reduce: A function taking a list of tensors and a list of devices,
+        and returning a list of reduced tensors and a list of ops to perform
+        the reduction.
+      numpy_fn: A function taking two tensors and returning the reduction of
+        the two.
+    """
     if not test.is_gpu_available():
       return  # Test requires access to a GPU
 
@@ -85,69 +73,96 @@
     for dtype in [np.float32, np.int32, np.int64, np.float64]:
       # Create session inside outer loop to test use of
      # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess:
-        for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'],
+
+        for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
                         ['/device:GPU:1', '/device:GPU:0']]:
           shape = (3, 4)
-          sender = np.random.randint(0, len(devices) - 1)
-          with ops.device(devices[sender]):
-            np_ans = ((
-                (np.random.random_sample(shape) - .5) * 1024).astype(dtype))
-            t = array_ops.identity(np_ans)
-          other_devices = devices[:sender] + devices[sender + 1:]
-          send_op, received_tensors = nccl.broadcast(t, other_devices)
-
-          # Verify shape inference.
-          for r in received_tensors:
+          random = (np.random.random_sample(shape) - .5) * 1024
+          tensors = [random.astype(dtype)] * len(devices)
+          np_ans = tensors[0]
+          for t in tensors[1:]:
+            np_ans = numpy_fn(np_ans, t)
+
+          reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
+          self.assertNotEmpty(reduce_tensors)
+
+          # Test shape inference.
+          for r in reduce_tensors:
             self.assertEqual(shape, r.get_shape())
 
-          # Run and verify results.
-          nccl_results = sess.run(received_tensors + [send_op])
-          for r in nccl_results[:-1]:
+          # Test execution and results.
+          nccl_results = sess.run(reduce_tensors + reduce_ops)
+          for r in nccl_results[:len(reduce_tensors)]:
             self.assertAllClose(r, np_ans)
 
+  def _TestGradient(self, nccl_reduce, numpy_fn):
+    """Tests the gradient of nccl_reduce.
 
-class CombinedTest(test.TestCase):
-  """Tests using a mix of all-reduce ops in one session.run call."""
+    Args:
+      nccl_reduce: A function taking a list of tensors and a list of devices,
+        and returning a list of reduced tensors and a list of ops to perform
+        the reduction.
+      numpy_fn: A function taking two tensors and returning the gradient of
+        the reduction of the two.
+    """
+    def _Gradient(tensors, devices):
+      reduce_tensors, _ = nccl_reduce(tensors, devices)
+      tensor_ops = [t.op for t in reduce_tensors]
+      d_tensors = _DeviceTensors(tensors, devices)
+      grad_tensors = [
+          ops.get_gradient_function(op)(op, loss)
+          for op, loss in zip(tensor_ops, d_tensors)
+      ]
+      return grad_tensors, []
 
-  def testCombined(self):
-    if not test.is_gpu_available():
-      return  # Test requires access to a GPU
+    self._Test(_Gradient, numpy_fn)
 
-    for dtype in [np.float32, np.int32, np.int64, np.float64]:
-      # Create session inside outer loop to test use of
-      # same communicator across multiple sessions.
-      with self.test_session(use_gpu=True) as sess:
-        for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
-                        ['/device:GPU:0', '/device:GPU:1']]:
-          shape = (3, 4)
-          # all-reduce
-          np_ans = np.zeros(shape=shape, dtype=dtype)
-          tensors = []
-          for d in devices:
-            with ops.device(d):
-              t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
-              np_ans += t
-              tensors.append(array_ops.identity(t))
-          all_reduce_tensors = nccl.all_sum(tensors)
-
-          sender = np.random.randint(0, len(devices) - 1)
-          other_devices = devices[:sender] + devices[sender + 1:]
-          send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
-                                                     other_devices)
-
-          # sender doesn't need to be fetched as part of outputs of session.run.
-          del all_reduce_tensors[sender]
-
-          # Verify shape inference.
-          for r in received_tensors:
-            self.assertEqual(shape, r.get_shape())
+
+class AllReduceTest(NcclTestCase):
 
-          # Run and verify results.
-          nccl_results = sess.run(
-              received_tensors + [send_op] + all_reduce_tensors)
-          for r in nccl_results[:len(received_tensors)]:
-            self.assertAllClose(r, np_ans)
+  def testAllReduce(self):
+    self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
+    self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
+    self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
+    self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
+
+  def testAllSumGrad(self):
+    self._TestGradient(
+        partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
+
+  def testErrors(self):
+    with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
+      nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
+    with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
+      nccl.all_sum([])
+
+
+class SingleReduceTest(NcclTestCase):
+
+  def testSum(self):
+    self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
+
+
+class BroadcastTest(NcclTestCase):
+
+  def testBroadcast(self):
+    self._Test(_NcclBroadcast, lambda x, y: x)
+
+
+class CombinedTest(NcclTestCase):
+  """Test all-reduce vs. single-reduce plus broadcast in one session.run."""
+
+  def _combined(self, tensors, devices):
+    all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
+    single_reduce_tensors, single_reduce_ops = _NcclReduce(
+        nccl.reduce_sum, tensors, devices)
+    broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
+                                                      devices)
+    all_tensors = (all_reduce_tensors + single_reduce_tensors +
+                   broadcast_tensors)
+    return all_tensors, single_reduce_ops + broadcast_ops
+
+  def testCombined(self):
+    self._Test(self._combined, lambda x, y: x + y)
 
 
 if __name__ == '__main__':
   test.main()