aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/nccl
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 02:00:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 02:04:54 -0700
commit9b9e54b344803994cfd7997035c62ecbbf24d152 (patch)
treecc09fbd68d62ade219f993138ce40ce51155ccad /tensorflow/contrib/nccl
parentbc300318e7bb9c2b5f1dcbfdc6f2d97d5279abf8 (diff)
Adding NCCL sum op, register all_sum gradient.
Streamlining nccl test. PiperOrigin-RevId: 168347428
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r--tensorflow/contrib/nccl/__init__.py2
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.cc41
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.h18
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_ops.cc96
-rw-r--r--tensorflow/contrib/nccl/ops/nccl_ops.cc45
-rw-r--r--tensorflow/contrib/nccl/python/ops/nccl_ops.py125
-rw-r--r--tensorflow/contrib/nccl/python/ops/nccl_ops_test.py211
7 files changed, 417 insertions, 121 deletions
diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py
index d851c522c0..4a682cb703 100644
--- a/tensorflow/contrib/nccl/__init__.py
+++ b/tensorflow/contrib/nccl/__init__.py
@@ -18,6 +18,7 @@
@@all_min
@@all_prod
@@all_sum
+@@reduce_sum
@@broadcast
"""
@@ -31,6 +32,7 @@ from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
+from tensorflow.contrib.nccl.python.ops.nccl_ops import reduce_sum
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 42e7789301..4b642f64c1 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -260,7 +260,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
std::vector<ncclComm_t> nccl_comms(num_devices);
auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
- CHECK_EQ(result, ncclSuccess);
+ CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
for (int rank = 0; rank < num_devices; ++rank) {
members[rank].nccl_comm = nccl_comms[rank];
}
@@ -307,6 +307,35 @@ void NcclManager::AddBroadcastRecv(
kBroadcast, ncclSum /* unused */);
}
+void NcclManager::AddReduceSend(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* temp_t,
+ DoneCallback done_callback) {
+ std::unique_ptr<Participant> participant(
+ new Participant(in_t, temp_t, event_mgr, tensor_stream, executor,
+ gpu_device_id, std::move(done_callback)));
+ AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+ kReduce, reduction_op);
+}
+
+void NcclManager::AddReduceRecv(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* out_t,
+ DoneCallback done_callback) {
+ std::unique_ptr<Participant> participant(
+ new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+ gpu_device_id, std::move(done_callback)));
+ participant->root = true;
+ AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+ kReduce, reduction_op);
+}
+
void NcclManager::AddParticipant(int num_devices, const string& key,
std::unique_ptr<Participant> participant,
DataType data_type,
@@ -431,6 +460,14 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
collective->root_rank, nccl_comm, *cu_stream);
break;
}
+ case kReduce: {
+ const void* sendbuff = p->in_t->tensor_data().data();
+ void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
+ nccl_result = ncclReduce(sendbuff, recvbuff, p->in_t->NumElements(),
+ data_type, collective->reduction_op,
+ collective->root_rank, nccl_comm, *cu_stream);
+ break;
+ }
}
// Run the done_callback when the nccl kernel finishes running.
@@ -441,7 +478,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
// Propagate the error, but note that if other members of the collective
// did launch their kernels, then they are hanging.
collective->participants[rank]->done_callback(errors::Unknown(
- "Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
+ "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
}
// TODO(cwhipkey): use RefCounted after figuring out how to use in a
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 6e2f8e953a..619a1b69bf 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -75,10 +75,28 @@ class NcclManager {
perftools::gputools::Stream* tensor_stream,
Tensor* out_t, DoneCallback done_callback);
+ // AddReduceSend and AddReduceRecv combine to sent data from all senders
+ // to one receiver.
+ void AddReduceSend(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* temp_t,
+ DoneCallback done_callback);
+ void AddReduceRecv(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* out_t,
+ DoneCallback done_callback);
+
private:
enum CollectiveType {
kAllReduce = 1,
kBroadcast = 2,
+ kReduce = 3,
};
struct Collective;
struct Communicator;
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
index d4455483f7..81cc74416b 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
#if GOOGLE_CUDA
+#include <memory>
#include <unordered_map>
#include <vector>
@@ -58,11 +59,9 @@ class NcclAsyncOpBase : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
};
-// To execute a single all-reduce, this kernel is called once for each of the
-// <k> devices in the communicator.
-class NcclAllReduceOpKernel : public NcclAsyncOpBase {
+class NcclReduceOpBase : public NcclAsyncOpBase {
public:
- explicit NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
+ explicit NcclReduceOpBase(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
string reduction;
OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
if (reduction == "min") {
@@ -79,6 +78,19 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
}
}
+ ncclRedOp_t reduction_op() const { return reduction_op_; }
+
+ private:
+ ncclRedOp_t reduction_op_;
+};
+
+// To execute a single all-reduce, this kernel is called once for each of the
+// <k> devices in the communicator.
+class NcclAllReduceOpKernel : public NcclReduceOpBase {
+ public:
+ explicit NcclAllReduceOpKernel(OpKernelConstruction* c)
+ : NcclReduceOpBase(c) {}
+
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
const Tensor* in_t = &c->input(0);
Tensor* out_t;
@@ -92,18 +104,81 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
auto* compute_stream = c->op_device_context()->stream();
auto* gpu_info = c->device()->tensorflow_gpu_device_info();
NcclManager::instance()->AddToAllReduce(
- num_devices(), GetCollectiveKey(c), reduction_op_,
+ num_devices(), GetCollectiveKey(c), reduction_op(),
compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
- compute_stream, in_t, out_t, actual_done);
+ compute_stream, in_t, out_t, std::move(actual_done));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
+ NcclAllReduceOpKernel);
+
+// To execute a single reduce, this kernel is called once for all but one of the
+// <k> devices in the communicator, and NcclReduceRecvKernel is called once for
+// the remaining device.
+class NcclReduceSendKernel : public NcclReduceOpBase {
+ public:
+ explicit NcclReduceSendKernel(OpKernelConstruction* c)
+ : NcclReduceOpBase(c) {}
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ const Tensor& in_t = c->input(0);
+ std::unique_ptr<Tensor> temp_ptr(new Tensor());
+ OP_REQUIRES_OK_ASYNC(
+ c, c->allocate_temp(in_t.dtype(), in_t.shape(), temp_ptr.get()), done);
+ Tensor* temp_t = temp_ptr.release();
+
+ auto actual_done = [c, done, temp_t](Status s) {
+ delete temp_t;
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+
+ auto* compute_stream = c->op_device_context()->stream();
+ auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+ NcclManager::instance()->AddReduceSend(
+ num_devices(), GetCollectiveKey(c), reduction_op(),
+ compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+ compute_stream, &in_t, temp_t, std::move(actual_done));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU),
+ NcclReduceSendKernel);
+
+// To execute a single reduce, this kernel is called once for one devices, and
+// NcclReduceSendKernel is called for all other <k-1> devices in the
+// communicator.
+class NcclReduceRecvKernel : public NcclReduceOpBase {
+ public:
+ explicit NcclReduceRecvKernel(OpKernelConstruction* c)
+ : NcclReduceOpBase(c) {}
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ const Tensor& in_t = c->input(0);
+ Tensor* out_t;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t.shape(), &out_t), done);
+
+ auto actual_done = [c, done](Status s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+
+ auto* compute_stream = c->op_device_context()->stream();
+ auto* gpu_info = c->device()->tensorflow_gpu_device_info();
+ NcclManager::instance()->AddReduceRecv(
+ num_devices(), GetCollectiveKey(c), reduction_op(),
+ compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+ compute_stream, &in_t, out_t, std::move(actual_done));
}
private:
ncclRedOp_t reduction_op_;
};
+REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU),
+ NcclReduceRecvKernel);
-REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
- NcclAllReduceOpKernel);
-
+// To execute a single broadcast, this kernel is called once for one device, and
+// NcclBroadcastRecvKernel is called for all other <k-1> devices in the
+// communicator.
class NcclBroadcastSendKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastSendKernel(OpKernelConstruction* c)
@@ -126,6 +201,9 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
NcclBroadcastSendKernel);
+// To execute a single broadcast, this kernel is called once for all but one of
+// the <k> devices in the communicator, and NcclBroadcastSendKernel is called
+// once for the remaining device.
class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
public:
explicit NcclBroadcastRecvKernel(OpKernelConstruction* c)
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
index d767636fad..532c79c24c 100644
--- a/tensorflow/contrib/nccl/ops/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -45,6 +45,51 @@ num_devices: The number of devices participating in this reduction.
shared_name: Identifier that shared between ops of the same reduction.
)doc");
+REGISTER_OP("NcclReduceSend")
+ .Input("input: T")
+ .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .Attr("shared_name: string")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+ )doc");
+
+REGISTER_OP("NcclReduceRecv")
+ .Input("input: T")
+ .Output("data: T")
+ .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .Attr("shared_name: string")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Reduces 'input' from this op and the NcclReduceSend ops registered in the same
+`shared_name`.
+
+The graph should be constructed so that 'num_devices-1' devices run
+`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value
+`c`. Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the reduction
+data: The reduced data received from this op and the NcclReduceSend op.
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduce.
+ )doc");
+
REGISTER_OP("NcclBroadcastSend")
.Input("input: T")
.Attr("T: {float, float64, int32, int64}")
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
index b31cc53e0a..906d9f948a 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.nccl.ops import gen_nccl_ops
from tensorflow.contrib.util import loader
+from tensorflow.python.eager import context
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -48,6 +49,35 @@ def all_sum(tensors):
return _apply_all_reduce('sum', tensors)
+@ops.RegisterGradient('NcclAllReduce')
+def _all_sum_grad(op, grad):
+ """The gradients for `all_sum`.
+
+ Args:
+ op: The `all_sum` `Operation` that we are differentiating.
+ grad: Gradient with respect to the output of the `all_sum` op.
+
+ Returns:
+ The gradient with respect to the output of `all_sum`.
+
+ Raises:
+ LookupError: If `reduction` is not `sum`.
+ """
+ if op.get_attr('reduction') != 'sum':
+ raise LookupError('No gradient defined for NcclAllReduce except all_sum.')
+
+ _check_device_assignment(grad)
+ num_devices = op.get_attr('num_devices')
+ shared_name = op.get_attr('shared_name') + '_grad'
+
+ with ops.device(grad.device):
+ return gen_nccl_ops.nccl_all_reduce(
+ input=grad,
+ reduction='sum',
+ num_devices=num_devices,
+ shared_name=shared_name)
+
+
def all_prod(tensors):
"""Returns a list of tensors with the all-reduce product across `tensors`.
@@ -99,6 +129,24 @@ def all_max(tensors):
return _apply_all_reduce('max', tensors)
+def reduce_sum(tensors, dst_device):
+ """Returns a tensor with the reduce sum across `tensors`.
+
+ The computation is done with a reduce operation, so only one tensor is
+ returned.
+
+ Args:
+ tensors: The input tensors across which to sum; must be assigned
+ to GPU devices.
+ dst_device: The device of the returned tensor.
+
+ Returns:
+ A tensor containing the sum of the input tensors, with the device of the
+ tensor being `dst_device`.
+ """
+ return _apply_reduce('sum', tensors, dst_device)
+
+
def broadcast(src_tensor, dst_devices):
"""Returns a list of tensors on `dst_devices`, each with value `tensor`.
@@ -111,50 +159,93 @@ def broadcast(src_tensor, dst_devices):
dst_devices: The GPU devices to receive the sent tensor.
Returns:
- List of tensors, each with the value of `src_tensor`, which the device
- of tensor i is `dst_devices[i]`.
+ An `Operation` to send the `src_tensor`, and a list of tensors, each with
+ the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`.
"""
if not dst_devices:
raise ValueError('Must pass >0 dst_devices to broadcast')
- all_devices = [src_tensor.device] + dst_devices
+ _check_graph_mode()
+ _check_device_assignment(src_tensor)
+
+ shape = array_ops.shape(src_tensor, out_type=dtypes.int64)
+ num_devices = len(dst_devices) + 1
shared_name = _get_shared_name()
with ops.device(src_tensor.device):
send = gen_nccl_ops.nccl_broadcast_send(
- input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
+ input=src_tensor, num_devices=num_devices, shared_name=shared_name)
- shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
recvs = []
for d in dst_devices:
with ops.device(d):
recvs.append(
gen_nccl_ops.nccl_broadcast_recv(
- shape=shape_op,
+ shape=shape,
T=src_tensor.dtype,
- num_devices=len(all_devices),
+ num_devices=num_devices,
shared_name=shared_name))
return send, recvs
-def _apply_all_reduce(reduction_op, tensors):
+def _apply_all_reduce(reduction, tensors):
+ """Helper function for all_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
+ _check_graph_mode()
+
shared_name = _get_shared_name()
res = []
+
for t in tensors:
- if not device.canonical_name(t.device):
- raise ValueError('Device assignment required for nccl collective ops')
+ _check_device_assignment(t)
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
- t,
- reduction=reduction_op,
+ input=t,
+ reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
+
return res
+def _apply_reduce(reduction, tensors, dst_device):
+ """Helper function for reduce_* functions."""
+ if not tensors:
+ raise ValueError('Must pass >0 tensors to reduce operations')
+ if not dst_device:
+ raise ValueError('Must pass dst_device to reduce operations')
+ _check_graph_mode()
+
+ try:
+ recv_index = next(i for i, t in enumerate(tensors)
+ if t.device == dst_device)
+ except StopIteration:
+ raise ValueError('One of the tensors must be assigned to dst_device')
+ shared_name = _get_shared_name()
+
+ sends = []
+ for t in tensors[:recv_index] + tensors[recv_index + 1:]:
+ _check_device_assignment(t)
+ with ops.device(t.device):
+ sends.append(
+ gen_nccl_ops.nccl_reduce_send(
+ input=t,
+ reduction=reduction,
+ num_devices=len(tensors),
+ shared_name=shared_name))
+
+ with ops.device(dst_device):
+ recv = gen_nccl_ops.nccl_reduce_recv(
+ input=tensors[recv_index],
+ reduction=reduction,
+ num_devices=len(tensors),
+ shared_name=shared_name)
+
+ return recv, sends
+
+
_lock = threading.Lock()
_shared_name_counter = 0
@@ -166,3 +257,13 @@ def _get_shared_name():
val = _shared_name_counter
_shared_name_counter += 1
return 'c%s' % val
+
+
+def _check_device_assignment(tensor):
+ if not device.canonical_name(tensor.device):
+ raise ValueError('Device assignment required for nccl collective ops')
+
+
+def _check_graph_mode():
+ if context.in_eager_mode():
+ raise ValueError('Nccl ops are not supported in eager mode')
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 1621e9f28e..96d67723a0 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from functools import partial
import numpy as np
from tensorflow.contrib import nccl
@@ -26,58 +27,45 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class AllReduceTest(test.TestCase):
+def _DeviceTensors(tensors, devices):
+ res = []
+ for t, d in zip(tensors, devices):
+ with ops.device(d):
+ res.append(array_ops.identity(t))
+ return res
- def testAllReduce(self):
- if not test.is_gpu_available():
- return # Test requires access to a GPU
- for dtype in [np.float32, np.int32, np.int64, np.float64]:
- # Create session inside outer loop to test use of
- # same communicator across multiple sessions.
- with self.test_session(use_gpu=True) as sess:
- self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
- self._testSingleAllReduce(sess, dtype, nccl.all_prod,
- lambda x, y: x * y)
- self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
- self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
-
- def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
- for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
- ['/device:GPU:1', '/device:GPU:0']]:
- shape = (3, 4)
- np_ans = None
- tensors = []
- for d in devices:
- with ops.device(d):
- t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
- if np_ans is None:
- np_ans = t
- else:
- np_ans = numpy_accumulation_fn(np_ans, t)
- tensors.append(array_ops.identity(t))
-
- all_reduce_tensors = nccl_fn(tensors)
-
- # Test shape inference.
- for r in all_reduce_tensors:
- self.assertEqual(shape, r.get_shape())
-
- # Test execution and results.
- nccl_results = sess.run(all_reduce_tensors)
- for r in nccl_results:
- self.assertAllClose(r, np_ans)
+def _NcclAllReduce(nccl_fun, tensors, devices):
+ return nccl_fun(_DeviceTensors(tensors, devices)), []
- def testErrors(self):
- with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
- nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
- with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
- nccl.all_sum([])
+def _NcclReduce(nccl_fun, tensors, devices):
+ d_tensors = _DeviceTensors(tensors, devices)
+ receiver = np.random.randint(0, len(devices))
+ received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver])
+ return [received_tensor], send_ops
-class BroadcastTest(test.TestCase):
- def testBroadcast(self):
+def _NcclBroadcast(tensors, devices):
+ sender = np.random.randint(0, len(devices))
+ d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0]
+ other_devices = devices[:sender] + devices[sender + 1:]
+ send_op, received_tensors = nccl.broadcast(d_tensor, other_devices)
+ return received_tensors, [send_op]
+
+
+class NcclTestCase(test.TestCase):
+
+ def _Test(self, nccl_reduce, numpy_fn):
+ """Tests that nccl_reduce does the same as reduction with numpy_fn.
+
+ Args:
+ nccl_reduce: A function taking a list of tensors and a list of devices,
+ and returns a list of reduced tensors and a list of ops to perform the
+ reduction.
+ numpy_fn: A function taking two tensors and returning the reduction of the
+ two.
+ """
if not test.is_gpu_available():
return # Test requires access to a GPU
@@ -85,69 +73,96 @@ class BroadcastTest(test.TestCase):
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
- for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'],
+
+ for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
['/device:GPU:1', '/device:GPU:0']]:
shape = (3, 4)
- sender = np.random.randint(0, len(devices) - 1)
- with ops.device(devices[sender]):
- np_ans = ((
- (np.random.random_sample(shape) - .5) * 1024).astype(dtype))
- t = array_ops.identity(np_ans)
- other_devices = devices[:sender] + devices[sender + 1:]
- send_op, received_tensors = nccl.broadcast(t, other_devices)
-
- # Verify shape inference.
- for r in received_tensors:
+ random = (np.random.random_sample(shape) - .5) * 1024
+ tensors = [random.astype(dtype)] * len(devices)
+ np_ans = tensors[0]
+ for t in tensors[1:]:
+ np_ans = numpy_fn(np_ans, t)
+
+ reduce_tensors, reduce_ops = nccl_reduce(tensors, devices)
+ self.assertNotEmpty(reduce_tensors)
+
+ # Test shape inference.
+ for r in reduce_tensors:
self.assertEqual(shape, r.get_shape())
- # Run and verify results.
- nccl_results = sess.run(received_tensors + [send_op])
- for r in nccl_results[:-1]:
+ # Test execution and results.
+ nccl_results = sess.run(reduce_tensors + reduce_ops)
+ for r in nccl_results[:len(reduce_tensors)]:
self.assertAllClose(r, np_ans)
+ def _TestGradient(self, nccl_reduce, numpy_fn):
+ """Tests the gradient of nccl_reduce.
-class CombinedTest(test.TestCase):
- """Tests using a mix of all-reduce ops in one session.run call."""
+ Args:
+ nccl_reduce: A function taking a list of tensors and a list of devices,
+ and returns a list of reduced tensors and a list of ops to perform the
+ reduction.
+ numpy_fn: A function taking two tensors and returning the gradient of the
+ reduction of the two.
+ """
+ def _Gradient(tensors, devices):
+ reduce_tensors, _ = nccl_reduce(tensors, devices)
+ tensor_ops = [t.op for t in reduce_tensors]
+ d_tensors = _DeviceTensors(tensors, devices)
+ grad_tensors = [
+ ops.get_gradient_function(op)(op, loss)
+ for op, loss in zip(tensor_ops, d_tensors)
+ ]
+ return grad_tensors, []
- def testCombined(self):
- if not test.is_gpu_available():
- return # Test requires access to a GPU
+ self._Test(_Gradient, numpy_fn)
- for dtype in [np.float32, np.int32, np.int64, np.float64]:
- # Create session inside outer loop to test use of
- # same communicator across multiple sessions.
- with self.test_session(use_gpu=True) as sess:
- for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'],
- ['/device:GPU:0', '/device:GPU:1']]:
- shape = (3, 4)
- # all-reduce
- np_ans = np.zeros(shape=shape, dtype=dtype)
- tensors = []
- for d in devices:
- with ops.device(d):
- t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
- np_ans += t
- tensors.append(array_ops.identity(t))
- all_reduce_tensors = nccl.all_sum(tensors)
-
- sender = np.random.randint(0, len(devices) - 1)
- other_devices = devices[:sender] + devices[sender + 1:]
- send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
- other_devices)
-
- # sender doesn't need to be fetched as part of outputs of session.run.
- del all_reduce_tensors[sender]
-
- # Verify shape inference.
- for r in received_tensors:
- self.assertEqual(shape, r.get_shape())
+class AllReduceTest(NcclTestCase):
- # Run and verify results.
- nccl_results = sess.run(
- received_tensors + [send_op] + all_reduce_tensors)
- for r in nccl_results[:len(received_tensors)]:
- self.assertAllClose(r, np_ans)
+ def testAllReduce(self):
+ self._Test(partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
+ self._Test(partial(_NcclAllReduce, nccl.all_prod), lambda x, y: x * y)
+ self._Test(partial(_NcclAllReduce, nccl.all_min), np.minimum)
+ self._Test(partial(_NcclAllReduce, nccl.all_max), np.maximum)
+
+ def testAllSumGrad(self):
+ self._TestGradient(
+ partial(_NcclAllReduce, nccl.all_sum), lambda x, y: x + y)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
+ nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
+ with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
+ nccl.all_sum([])
+
+
+class SingleReduceTest(NcclTestCase):
+
+ def testSum(self):
+ self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y)
+
+
+class BroadcastTest(NcclTestCase):
+
+ def testBroadcast(self):
+ self._Test(_NcclBroadcast, lambda x, y: x)
+
+
+class CombinedTest(NcclTestCase):
+ """Test all-reduce vs. single-reduce plus broadcast in one session.run."""
+
+ def _combined(self, tensors, devices):
+ all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0]
+ single_reduce_tensors, single_reduce_ops = _NcclReduce(
+ nccl.reduce_sum, tensors, devices)
+ broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors,
+ devices)
+ all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors
+ return all_tensors, single_reduce_ops + broadcast_ops
+
+ def testCombined(self):
+ self._Test(self._combined, lambda x, y: x + y)
if __name__ == '__main__':