author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-21 18:04:49 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2017-02-21 18:31:12 -0800
commit    | f3405c2d73196e409041d52bbf30748b2a64493b (patch)
tree      | 297974084329b2789486bcf6d7a962dce94ae096 /tensorflow/contrib/nccl
parent    | 4c26c3211520c81436f244add5c45e8fb0f6adec (diff)
Change nccl_manager to use ncclCommInitAll.
Change: 148169806
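
For background on the API this change adopts: ncclCommInitAll initializes a full clique of communicators for a list of GPU device ids in one call from a single process, so the manager no longer needs one thread per rank just to call ncclCommInitRank with the right device active. The sketch below is illustrative only (not TensorFlow code); it assumes the NCCL headers are available, that the ids in `devices` name GPUs present on the machine, and the helper name `InitCommsForDevices` is made up for the example.

```cpp
#include <vector>

#include "nccl.h"

// Build one NCCL communicator per listed GPU from a single thread.
// On failure the vector is returned empty; on success the caller owns the
// communicators and should eventually free them with ncclCommDestroy().
std::vector<ncclComm_t> InitCommsForDevices(const std::vector<int>& devices) {
  std::vector<ncclComm_t> comms(devices.size());
  // A single call replaces a per-rank ncclCommInitRank loop that needed each
  // rank's CUDA device to be current on its own calling thread.
  ncclResult_t result = ncclCommInitAll(
      comms.data(), static_cast<int>(devices.size()), devices.data());
  if (result != ncclSuccess) comms.clear();
  return comms;
}
```

This mirrors the shape of the new GetCommunicator path in the diff below, which gathers each participant's gpu_device_id into a vector and hands the whole list to ncclCommInitAll at once.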
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.cc      | 59
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.h       |  6
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager_test.cc | 11
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_ops.cc          | 20
4 files changed, 44 insertions(+), 52 deletions(-)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 31e85b571d..dfdfbc8eea 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -92,13 +92,14 @@ ncclDataType_t ToNcclType(DataType t) {
 struct NcclManager::Participant {
   Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
               perftools::gputools::Stream* tensor_stream,
-              perftools::gputools::StreamExecutor* executor,
+              perftools::gputools::StreamExecutor* executor, int gpu_device_id,
               NcclManager::DoneCallback done_callback)
       : in_t(in_t),
         out_t(out_t),
         event_mgr(event_mgr),
         tensor_stream(tensor_stream),
         executor(executor),
+        gpu_device_id(gpu_device_id),
         done_callback(std::move(done_callback)) {
     DCHECK(executor != nullptr);
     DCHECK(event_mgr != nullptr);
@@ -120,7 +121,9 @@ struct NcclManager::Participant {
 
   // Matches the executor in CommunicatorMember::stream. Expected to be live for
   // process lifetime.
-  perftools::gputools::StreamExecutor* executor = nullptr;
+  perftools::gputools::StreamExecutor* const executor = nullptr;
+
+  const int gpu_device_id;
 
   NcclManager::DoneCallback done_callback;
 
@@ -222,6 +225,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
   // Note that this is done under the lock; performance is not expected to
   // matter as this happens a very small number of times.
   std::vector<CommunicatorMember> members(num_devices);
+  std::vector<int> devices(num_devices);
   for (int i = 0; i < num_devices; ++i) {
     auto* executor = collective->participants[i]->executor;
 
@@ -249,30 +253,14 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
     }
 
     members[i].nccl_stream = nccl_stream;
+    devices[i] = collective->participants[i]->gpu_device_id;
   }
 
-  // Call ncclCommInitRank for each member.
-  ncclUniqueId id;
-  CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
-  std::unique_ptr<thread::ThreadPool> pool(
-      new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
-  std::vector<ncclResult_t> results(num_devices);
+  std::vector<ncclComm_t> nccl_comms(num_devices);
+  auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
+  CHECK_EQ(result, ncclSuccess);
   for (int rank = 0; rank < num_devices; ++rank) {
-    CommunicatorMember* member = &members[rank];
-    ncclResult_t* result = &results[rank];
-    pool->Schedule([member, num_devices, result, rank, &id]() {
-      ScopedActivateExecutorContext scoped_context(
-          member->nccl_stream->executor);
-      LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
-      *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
-      LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : "
-                << *result;
-    });
-  }
-
-  pool.reset();  // wait for completion.
-  for (int i = 0; i < num_devices; ++i) {
-    CHECK_EQ(results[i], ncclSuccess);
+    members[rank].nccl_comm = nccl_comms[rank];
   }
   communicators_.emplace_back(new Communicator(std::move(members)));
   return communicators_.back().get();
@@ -281,24 +269,25 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
 void NcclManager::AddToAllReduce(int num_devices, const string& key,
                                  ncclRedOp_t reduction_op,
                                  perftools::gputools::StreamExecutor* executor,
-                                 EventMgr* event_mgr,
+                                 int gpu_device_id, EventMgr* event_mgr,
                                  perftools::gputools::Stream* tensor_stream,
                                  const Tensor* in_t, Tensor* out_t,
                                  const DoneCallback& done_callback) {
-  std::unique_ptr<Participant> participant(new Participant(
-      in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kAllReduce, reduction_op);
 }
 
 void NcclManager::AddBroadcastSend(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    const Tensor* in_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   participant->root = true;
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
@@ -306,12 +295,12 @@ void NcclManager::AddBroadcastSend(
 
 void NcclManager::AddBroadcastRecv(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, Tensor* out_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    Tensor* out_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
 }
@@ -331,7 +320,7 @@ void NcclManager::AddParticipant(int num_devices, const string& key,
   }
   Collective* collective = collective_ptr.get();
   DCHECK_EQ(collective->type, collective_type);
-  DCHECK_EQ(collective->participants.size(), num_devices);
+  DCHECK_LT(collective->participants.size(), num_devices);
   collective->participants.emplace_back(std::move(participant));
   ++collective->available_participants;
 
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 8d5e5ddf76..1a661e8f7f 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -57,7 +57,7 @@ class NcclManager {
   void AddToAllReduce(int num_devices, const string& key,
                       ncclRedOp_t reduction_op,
                       perftools::gputools::StreamExecutor* executor,
-                      EventMgr* event_mgr,
+                      int gpu_device_id, EventMgr* event_mgr,
                       perftools::gputools::Stream* tensor_stream,
                       const Tensor* in_t, Tensor* out_t,
                       const DoneCallback& done_callback);
@@ -66,12 +66,12 @@ class NcclManager {
   // to all receivers.
   void AddBroadcastSend(int num_devices, const string& key,
                         perftools::gputools::StreamExecutor* executor,
-                        EventMgr* event_mgr,
+                        int gpu_device_id, EventMgr* event_mgr,
                         perftools::gputools::Stream* tensor_stream,
                         const Tensor* in_t, DoneCallback done_callback);
 
   void AddBroadcastRecv(int num_devices, const string& key,
                         perftools::gputools::StreamExecutor* executor,
-                        EventMgr* event_mgr,
+                        int gpu_device_id, EventMgr* event_mgr,
                         perftools::gputools::Stream* tensor_stream,
                         Tensor* out_t, DoneCallback done_callback);
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index b53cb82440..505c4b0d71 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -193,9 +193,9 @@ TEST_F(NcclManagerTest, BasicSumReduction) {
       auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       NcclManager::instance()->AddToAllReduce(
-          num_ranks, "allreduce", reduction_op, device->executor(), event_mgr,
-          stream, &test_case->ins[device_num], &test_case->outs[device_num],
-          CreateDoneCallback(test_case.get()));
+          num_ranks, "allreduce", reduction_op, device->executor(),
+          device->gpu_id(), event_mgr, stream, &test_case->ins[device_num],
+          &test_case->outs[device_num], CreateDoneCallback(test_case.get()));
     }
 
     LOG(ERROR) << "Verifying results";
@@ -259,8 +259,9 @@ TEST_F(NcclManagerTest, MultipleCallers) {
      TestCase* test_case = test_cases[test_num].get();
      NcclManager::instance()->AddToAllReduce(
          num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
-         device->executor(), event_mgr, stream, &test_case->ins[device_num],
-         &test_case->outs[device_num], CreateDoneCallback(test_case));
+         device->executor(), device->gpu_id(), event_mgr, stream,
+         &test_case->ins[device_num], &test_case->outs[device_num],
+         CreateDoneCallback(test_case));
    };
    pool->Schedule(fn);
  }
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
index db6ee3e0e7..b63ab5d611 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -90,11 +90,11 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddToAllReduce(
         num_devices(), GetCollectiveKey(c), reduction_op_,
-        compute_stream->parent(), event_mgr, compute_stream, in_t, out_t,
-        actual_done);
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, in_t, out_t, actual_done);
   }
 
  private:
@@ -115,10 +115,11 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddBroadcastSend(
-        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
-        compute_stream, &c->input(0), std::move(actual_done));
+        num_devices(), GetCollectiveKey(c), compute_stream->parent(),
+        gpu_info->gpu_id, gpu_info->event_mgr, compute_stream, &c->input(0),
+        std::move(actual_done));
   }
 };
 REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
@@ -142,10 +143,11 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddBroadcastRecv(
-        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
-        compute_stream, out_t, std::move(actual_done));
+        num_devices(), GetCollectiveKey(c), compute_stream->parent(),
+        gpu_info->gpu_id, gpu_info->event_mgr, compute_stream, out_t,
+        std::move(actual_done));
   }
 };
 REGISTER_KERNEL_BUILDER(