author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-24 15:19:38 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2017-01-24 15:29:03 -0800
commit    | 5be95cbb389bc112161232c8514155947063ea72 (patch)
tree      | f0812bf2efeb798155000d6133f9e58fd3738c86 /tensorflow/contrib/nccl
parent    | 761405e7202e1bec875f1ca7d1a7660ebbb3dafb (diff)
Add contrib/nccl for using all-reduce collectives across GPUs of a single
server.
Change: 145475050
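
The public surface added by this change is the Python wrapper in tensorflow/contrib/nccl/python/ops/nccl_ops.py (included in the diff below), which exposes all_sum, all_prod, all_min, all_max, and broadcast. A minimal usage sketch of the all-reduce path, assuming a machine with two GPUs; the device names and tensor values are illustrative, not part of the change:

```python
import tensorflow as tf
from tensorflow.contrib import nccl

# Build one input tensor per GPU; every input must be placed on a GPU device,
# otherwise _apply_all_reduce raises "Device assignment required".
tensors = []
for d in ['/gpu:0', '/gpu:1']:
  with tf.device(d):
    tensors.append(tf.constant([1.0, 2.0, 3.0]))

# all_sum returns one output per input, placed on the same device as the
# corresponding input. All outputs must be fetched together, otherwise the
# collective hangs waiting for the missing participants.
summed = nccl.all_sum(tensors)

with tf.Session() as sess:
  print(sess.run(summed))  # each output is [2.0, 4.0, 6.0]
```

Each NcclAllReduce op built this way shares a generated shared_name (see _get_shared_name in nccl_ops.py), which is how the kernels of the same collective find each other at run time.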
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r-- | tensorflow/contrib/nccl/BUILD                        | 120
-rw-r--r-- | tensorflow/contrib/nccl/__init__.py                  |  24
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.cc      | 471
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.h       | 122
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager_test.cc | 285
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_ops.cc          | 157
-rw-r--r-- | tensorflow/contrib/nccl/ops/nccl_ops.cc              |  94
-rw-r--r-- | tensorflow/contrib/nccl/python/ops/nccl_ops.py       | 168
-rw-r--r-- | tensorflow/contrib/nccl/python/ops/nccl_ops_test.py  | 151
9 files changed, 1592 insertions, 0 deletions
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD new file mode 100644 index 0000000000..7c352dae88 --- /dev/null +++ b/tensorflow/contrib/nccl/BUILD @@ -0,0 +1,120 @@ +# Description: +# Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops. +# APIs are meant to change over time. +package( + default_visibility = ["//visibility:private"], + features = ["-parse_headers"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cuda_cc_test", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +tf_custom_op_library( + name = "python/ops/_nccl_ops.so", + srcs = [ + "kernels/nccl_manager.cc", + "kernels/nccl_manager.h", + "kernels/nccl_ops.cc", + "ops/nccl_ops.cc", + ], + deps = [ + "//tensorflow/core:gpu_headers_lib", + "@nccl_archive//:nccl", + ], +) + +tf_gen_op_libs( + op_lib_names = ["nccl_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "nccl_ops", + deps = [":nccl_ops_op_lib"], +) + +py_library( + name = "nccl_py", + srcs = [ + "__init__.py", + "python/ops/nccl_ops.py", + ], + data = [ + ":python/ops/_nccl_ops.so", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":nccl_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +cuda_py_test( + name = "nccl_ops_test", + size = "small", + srcs = ["python/ops/nccl_ops_test.py"], + additional_deps = [ + ":nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "requires_cudnn5", + ], +) + +tf_cuda_cc_test( + name = "nccl_manager_test", + size = "small", + srcs = if_cuda( + [ + "kernels/nccl_manager.cc", + "kernels/nccl_manager.h", + "kernels/nccl_manager_test.cc", + ], + [], + ), + deps = if_cuda( + [ + "@nccl_archive//:nccl", + "//tensorflow/core", + "//tensorflow/core:cuda", + ], + [], + ) + [ + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py new file mode 100644 index 0000000000..0275ed6079 --- /dev/null +++ b/tensorflow/contrib/nccl/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Ops for nccl AllReduce.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.nccl.python.ops.nccl_ops import * +# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc new file mode 100644 index 0000000000..31e85b571d --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -0,0 +1,471 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/nccl/kernels/nccl_manager.h" + +#ifdef GOOGLE_CUDA + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +using ::perftools::gputools::cuda::ScopedActivateExecutorContext; + +// Contains data for a single stream used for nccl communication; this includes +// a background thread that calls NcclManager::LoopKernelLaunches. +struct NcclManager::NcclStream { + public: + NcclStream() {} + ~NcclStream() { + mutex_lock l(mu); + shutdown_requested = true; + cv.notify_all(); + } + + perftools::gputools::StreamExecutor* executor = nullptr; + + // The stream on which to run the nccl collective. + // This is a different stream than the tensorflow compute stream. + std::unique_ptr<perftools::gputools::Stream> stream; + + // See NcclManager::LoopKernelLaunches for information on these. + std::unique_ptr<Thread> thread; + mutex mu; + condition_variable cv; + // Has collective,rank pairs. + std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu); + bool shutdown_requested GUARDED_BY(mu) = false; +}; + +struct NcclManager::CommunicatorMember { + public: + CommunicatorMember() {} + ~CommunicatorMember() { + if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm); + } + ncclComm_t nccl_comm; + + // Owned by NcclManager::device_to_comm_streams_. + NcclStream* nccl_stream = nullptr; +}; + +struct NcclManager::Communicator { + public: + Communicator(std::vector<CommunicatorMember> members) + : num_devices(members.size()), members(std::move(members)) {} + + const int num_devices; + const std::vector<CommunicatorMember> members; // indexed by rank. +}; + +namespace { +ncclDataType_t ToNcclType(DataType t) { + switch (t) { + case DT_FLOAT: + return ncclFloat; + case DT_DOUBLE: + return ncclDouble; + case DT_INT32: + return ncclInt; + case DT_INT64: + return ncclInt64; + default: + return ncclFloat; + } +} +} // namespace + +// A participant in a Collective. See <Collective> below. 
+struct NcclManager::Participant { + Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, + perftools::gputools::StreamExecutor* executor, + NcclManager::DoneCallback done_callback) + : in_t(in_t), + out_t(out_t), + event_mgr(event_mgr), + tensor_stream(tensor_stream), + executor(executor), + done_callback(std::move(done_callback)) { + DCHECK(executor != nullptr); + DCHECK(event_mgr != nullptr); + DCHECK(tensor_stream != nullptr); + } + // Owned by the caller, who must keep it live until <done_callback> is called. + // Is NULL for participants that only receive data. + const Tensor* in_t; + + // Owned by the caller, who must keep it live until <done_callback> is called. + // Is NULL for participants that only send data. + Tensor* out_t; + + // Owned by the caller, who must keep it live until <done_callback> is called. + EventMgr* const event_mgr; + + // Owned by the caller, who must keep it live until <done_callback> is called. + perftools::gputools::Stream* const tensor_stream; + + // Matches the executor in CommunicatorMember::stream. Expected to be live for + // process lifetime. + perftools::gputools::StreamExecutor* executor = nullptr; + + NcclManager::DoneCallback done_callback; + + bool root = false; +}; + +// A Collective tracks a single communicator operation (e.g., a single +// AllReduce call). +struct NcclManager::Collective { + Collective(DataType data_type_in, CollectiveType type_in, + ncclRedOp_t reduction_op_in, int num_devices) + : data_type(data_type_in), + type(type_in), + reduction_op(reduction_op_in), + remaining_participants(num_devices) { + participants.reserve(num_devices); + } + + const DataType data_type; + const CollectiveType type; + const ncclRedOp_t reduction_op; // applies when <type> is a reduction. + + Communicator* communicator = nullptr; + + // All collective participants. + // + // Adding values in this vector is guarded by the mutex of the containing + // NcclManager. + std::vector<std::unique_ptr<Participant>> participants; + + // For collective types that have a root (e.g. the root of broadcast is the + // sender), this is the rank of the root. + int root_rank = -1; + + // How many participants have been registered so far. The Collective is + // eligible for running with <available_participants> == participants.size(). + // + // Guarded by the mutex of the containing Communicator. + int available_participants = 0; + + mutable std::atomic_int_fast32_t remaining_participants; +}; + +NcclManager::NcclManager() {} +NcclManager::~NcclManager() {} +NcclManager* NcclManager::instance() { + static NcclManager* instance = new NcclManager(); + return instance; +} + +NcclManager::Communicator* NcclManager::GetCommunicator( + NcclManager::Collective* collective) { + // Sort by executor to make ordering of executors deterministic. + std::sort(collective->participants.begin(), collective->participants.end(), + [](const std::unique_ptr<Participant>& a, + const std::unique_ptr<Participant>& b) { + return a->executor < b->executor; + }); + const int num_devices = collective->participants.size(); + + mutex_lock l(mu_); + + // Scan to find an existing communicator that provides nccl communication + // between the executors used by the participants in the collective. For + // example, if a collective is for GPUs 0, 1, and 2 then this will scan + // to find the communicator for GPUs 0, 1, and 2. 
+ // + // Note that each executor identifies a context on one device, so this is the + // same as getting the communicator connecting the devices in the collective. + // A device can be in different communicators as well - for example, a + // communicator for GPUs 0 and 1 is separate from one for GPUs 0, 1, and 2. + // + // Since it's expected that a small number of distinct communicators will + // be needed, communicators_ is not garbage collected currently. + // + // Launching of kernels must be serialized so that, given collectives A and B, + // and an order of them (e.g., A before B), then for each comm_stream + // involved, the kernel for A is launched before the kernel for B. This is + // guaranteed currently be a global mutex controlling additions of the kernels + // to per-stream launch queues. The launch queues are processed by + // LoopKernelLaunches. + for (auto& comm : communicators_) { + if (comm->num_devices == num_devices) { + int i; + for (i = 0; i < num_devices; ++i) { + if (comm->members[i].nccl_stream->executor != + collective->participants[i]->executor) { + break; + } + } + if (i == num_devices) return comm.get(); + } + } + + auto* env = Env::Default(); + std::set<NcclStream*> used_streams; + + // Create and initialize a new communicator. + // Note that this is done under the lock; performance is not expected to + // matter as this happens a very small number of times. + std::vector<CommunicatorMember> members(num_devices); + for (int i = 0; i < num_devices; ++i) { + auto* executor = collective->participants[i]->executor; + + // Find a communication stream to use for the device. + auto& streams = device_to_comm_streams_[executor]; + NcclStream* nccl_stream = nullptr; + for (const auto& s : streams) { + if (used_streams.insert(s.get()).second) { + nccl_stream = s.get(); + break; + } + } + if (nccl_stream == nullptr) { + nccl_stream = new NcclStream(); + nccl_stream->executor = executor; + nccl_stream->stream.reset(new perftools::gputools::Stream(executor)); + nccl_stream->stream->Init(); + + streams.emplace_back(nccl_stream); + used_streams.insert(nccl_stream); + + nccl_stream->thread.reset(env->StartThread( + ThreadOptions(), "nccl_kernel_launch", + [this, nccl_stream] { LoopKernelLaunches(nccl_stream); })); + } + + members[i].nccl_stream = nccl_stream; + } + + // Call ncclCommInitRank for each member. + ncclUniqueId id; + CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id)); + std::unique_ptr<thread::ThreadPool> pool( + new thread::ThreadPool(env, "ncclCommInitRank", num_devices)); + std::vector<ncclResult_t> results(num_devices); + for (int rank = 0; rank < num_devices; ++rank) { + CommunicatorMember* member = &members[rank]; + ncclResult_t* result = &results[rank]; + pool->Schedule([member, num_devices, result, rank, &id]() { + ScopedActivateExecutorContext scoped_context( + member->nccl_stream->executor); + LOG(INFO) << "Calling ncclCommInitRank for rank " << rank; + *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank); + LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : " + << *result; + }); + } + + pool.reset(); // wait for completion. 
+ for (int i = 0; i < num_devices; ++i) { + CHECK_EQ(results[i], ncclSuccess); + } + communicators_.emplace_back(new Communicator(std::move(members))); + return communicators_.back().get(); +} + +void NcclManager::AddToAllReduce(int num_devices, const string& key, + ncclRedOp_t reduction_op, + perftools::gputools::StreamExecutor* executor, + EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, + const Tensor* in_t, Tensor* out_t, + const DoneCallback& done_callback) { + std::unique_ptr<Participant> participant(new Participant( + in_t, out_t, event_mgr, tensor_stream, executor, done_callback)); + AddParticipant(num_devices, key, std::move(participant), in_t->dtype(), + kAllReduce, reduction_op); +} + +void NcclManager::AddBroadcastSend( + int num_devices, const string& key, + perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, const Tensor* in_t, + DoneCallback done_callback) { + std::unique_ptr<Participant> participant( + new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream, + executor, done_callback)); + participant->root = true; + AddParticipant(num_devices, key, std::move(participant), in_t->dtype(), + kBroadcast, ncclSum /* unused */); +} + +void NcclManager::AddBroadcastRecv( + int num_devices, const string& key, + perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, Tensor* out_t, + DoneCallback done_callback) { + std::unique_ptr<Participant> participant( + new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream, + executor, done_callback)); + AddParticipant(num_devices, key, std::move(participant), out_t->dtype(), + kBroadcast, ncclSum /* unused */); +} + +void NcclManager::AddParticipant(int num_devices, const string& key, + std::unique_ptr<Participant> participant, + DataType data_type, + CollectiveType collective_type, + ncclRedOp_t reduction_op) { + Collective* to_run = nullptr; + { + mutex_lock l(mu_); + auto& collective_ptr = collectives_[key]; + if (collective_ptr == nullptr) { + collective_ptr.reset(new Collective(data_type, collective_type, + reduction_op, num_devices)); + } + Collective* collective = collective_ptr.get(); + DCHECK_EQ(collective->type, collective_type); + DCHECK_EQ(collective->participants.size(), num_devices); + collective->participants.emplace_back(std::move(participant)); + ++collective->available_participants; + + if (collective->available_participants == num_devices) { + to_run = collective; + + // Ownership is going to be transferred to RunCollective. + collective_ptr.release(); + collectives_.erase(key); + } + } + + if (to_run != nullptr) { + RunCollective(key, to_run); + } +} + +void NcclManager::RunCollective(const string& key, Collective* collective) { + static mutex collective_mu; + + auto* communicator = GetCommunicator(collective); + collective->communicator = communicator; + const int size = communicator->num_devices; + + for (int rank = 0; rank < size; ++rank) { + Participant* p = collective->participants[rank].get(); + NcclStream* nccl_stream = communicator->members[rank].nccl_stream; + CHECK(nccl_stream != nullptr); + + if (p->in_t != nullptr) { + // Wait to ensure that the kernel that produces the data in the input + // tensor has finished running before the nccl kernel runs on the + // communication stream. 
+ nccl_stream->stream->ThenWaitFor(p->tensor_stream); + } + if (p->root) { + CHECK_EQ(collective->root_rank, -1); + collective->root_rank = rank; + } + } + + if (collective->type == kBroadcast) { + CHECK_NE(collective->root_rank, -1); + } + + { + // Allow only one collective at a time to queue kernels for launching. This + // is to prevent collectives from deadlocking each other. + // Note that it would be possible to run multiple collectives at once, if + // they have non-intersecting sets of devices. + mutex_lock l(collective_mu); + for (int rank = 0; rank < size; ++rank) { + NcclStream* nccl_stream = communicator->members[rank].nccl_stream; + mutex_lock l(nccl_stream->mu); + nccl_stream->pending_launches_.push_front( + std::make_pair(collective, rank)); + nccl_stream->cv.notify_all(); + } + } +} + +void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { + perftools::gputools::Stream* comm_stream = nccl_stream->stream.get(); + ScopedActivateExecutorContext scoped_context(nccl_stream->executor); + const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>( + comm_stream->implementation()->CudaStreamMemberHack()); + + while (true) { + // Find collective to run. + std::pair<Collective*, int> next_launch; + { + mutex_lock l(nccl_stream->mu); + while (nccl_stream->pending_launches_.empty()) { + if (nccl_stream->shutdown_requested) { + // No work and shutdown requested, exit. + return; + } + nccl_stream->cv.wait(l); + } + next_launch = nccl_stream->pending_launches_.back(); + nccl_stream->pending_launches_.pop_back(); + } + Collective* collective = next_launch.first; + int rank = next_launch.second; + + // Launch the nccl kernel. + ncclDataType_t data_type = ToNcclType(collective->data_type); + Participant* p = collective->participants[rank].get(); + + auto nccl_comm = collective->communicator->members[rank].nccl_comm; + ncclResult_t nccl_result = ncclSuccess; + switch (collective->type) { + case kAllReduce: { + const void* sendbuff = p->in_t->tensor_data().data(); + void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data()); + + nccl_result = + ncclAllReduce(sendbuff, recvbuff, p->in_t->NumElements(), data_type, + collective->reduction_op, nccl_comm, *cu_stream); + break; + } + case kBroadcast: { + const Tensor* buf_t = p->in_t ? p->in_t : p->out_t; + void* buf = const_cast<char*>(buf_t->tensor_data().data()); + nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type, + collective->root_rank, nccl_comm, *cu_stream); + break; + } + } + + // Run the done_callback when the nccl kernel finishes running. + auto done_callback = [collective, rank, nccl_result]() { + if (nccl_result == ncclSuccess) { + collective->participants[rank]->done_callback(Status::OK()); + } else { + // Propagate the error, but note that if other members of the collective + // did launch their kernels, then they are hanging. + collective->participants[rank]->done_callback(errors::Unknown( + "Error invoking AllReduce: ", ncclGetErrorString(nccl_result))); + } + + // TODO(cwhipkey): use RefCounted after figuring out how to use in a + // custom op library. + // See tensorflow/core/lib/core/refcount.h for details on this locking. 
+ if (collective->remaining_participants.load(std::memory_order_acquire) == + 1 || + collective->remaining_participants.fetch_sub(1) == 1) { + delete collective; + } + }; + p->event_mgr->ThenExecute(comm_stream, done_callback); + } +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h new file mode 100644 index 0000000000..8d5e5ddf76 --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h @@ -0,0 +1,122 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ + +#ifdef GOOGLE_CUDA + +#include <unordered_map> +#include <vector> + +#include "external/nccl_archive/src/nccl.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { + +// The communicator is used to make the asynchronous communicator calls and to +// manage the per-device streams used for communication. +// +// See nccl_ops.cc for example usage, including description of memory +// management and stream synchronization. +class NcclManager { + public: + typedef std::function<void(Status)> DoneCallback; + NcclManager(); + ~NcclManager(); + + static NcclManager* instance(); + + // Add one participant to an all-reduce, sending in data from <in_t> and + // receiving the result of the all-reduce in <out_t>. The device for this + // participant is managed by <executor>, and its events are polled by + // <event_mgr>. + // + // This is an asynchronous call. When <done_callback> is called, <out_t> has + // been set to the all-reduce result (note: the stream may not yet have been + // synced). + // + // <tensor_stream> is the stream that should be waited on to ensure <in_t>'s + // data is available on the GPU for the communication stream to access. It + // is also the stream that will use the produced data; <done_callback> is + // not called until the next kernel launched on <stream> would see the data. + void AddToAllReduce(int num_devices, const string& key, + ncclRedOp_t reduction_op, + perftools::gputools::StreamExecutor* executor, + EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, + const Tensor* in_t, Tensor* out_t, + const DoneCallback& done_callback); + + // AddBroadcastSend and AddBroadcastRecv combine to sent data from one sender + // to all receivers. 
+ void AddBroadcastSend(int num_devices, const string& key, + perftools::gputools::StreamExecutor* executor, + EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, + const Tensor* in_t, DoneCallback done_callback); + void AddBroadcastRecv(int num_devices, const string& key, + perftools::gputools::StreamExecutor* executor, + EventMgr* event_mgr, + perftools::gputools::Stream* tensor_stream, + Tensor* out_t, DoneCallback done_callback); + + private: + enum CollectiveType { + kAllReduce = 1, + kBroadcast = 2, + }; + struct Collective; + struct Communicator; + struct CommunicatorMember; + struct NcclStream; + struct Participant; + + Communicator* GetCommunicator(Collective* collective); + + void AddParticipant(int num_devices, const string& key, + std::unique_ptr<Participant> participant, + DataType data_type, CollectiveType collective_type, + ncclRedOp_t reduction_op); + + // Run <collective>. This calls takes ownership of <collective>. + void RunCollective(const string& key, Collective* collective); + void LoopKernelLaunches(NcclStream* stream); + + mutex mu_; + + // Maps key to collectives currently being assembled or run. + std::unordered_map<string, std::unique_ptr<Collective>> collectives_ + GUARDED_BY(mu_); + + // Maps a device to the communication streams that make up its collective. + // This is used to share the stream across different communicators that + // include the same device. + std::map<perftools::gputools::StreamExecutor*, + std::vector<std::unique_ptr<NcclStream>>> + device_to_comm_streams_ GUARDED_BY(mu_); + + std::vector<std::unique_ptr<Communicator>> communicators_; + + TF_DISALLOW_COPY_AND_ASSIGN(NcclManager); +}; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_ diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc new file mode 100644 index 0000000000..b53cb82440 --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -0,0 +1,285 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef GOOGLE_CUDA + +#include <algorithm> +#include <vector> + +#include "tensorflow/contrib/nccl/kernels/nccl_manager.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/gpu/gpu_device.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +static std::vector<BaseGPUDevice*> GetGPUDevices() { + std::vector<Device*> devices; + SessionOptions session_options; + session_options.env = Env::Default(); + Status s = DeviceFactory::GetFactory(DEVICE_GPU) + ->AddDevices(session_options, "", &devices); + TF_CHECK_OK(s); + std::vector<BaseGPUDevice*> gpus; + for (Device* d : devices) { + if (d->device_type() == "GPU") { + gpus.push_back(static_cast<BaseGPUDevice*>(d)); + } else { + delete d; + } + } + return gpus; +} + +class NcclManagerTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + setenv("NCCL_DEBUG", "INFO", 1 /* replace */); + devices = new std::vector<BaseGPUDevice*>(GetGPUDevices()); + CHECK(!devices->empty()); + LOG(ERROR) << "Running test with " << devices->size() << " gpus"; + } + static void TearDownTestCase() { + for (auto device : *devices) delete device; + delete devices; + } + + static Allocator* gpu_allocator(BaseGPUDevice* device) { + return device->GetStepAllocator(AllocatorAttributes(), + nullptr /* step_resource_manager */); + } + + static std::vector<BaseGPUDevice*>* devices; + + template <typename Scalar> + perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory( + const Scalar* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped( + const_cast<Scalar*>(cuda_memory)); + perftools::gputools::DeviceMemory<Scalar> typed(wrapped); + return typed; + } + + // A single all-reduce to apply. 
+ struct TestCase { + string key; + std::vector<Tensor> ins; + std::vector<Tensor> outs; + Tensor expected; + + mutex mu; + Status final_status; + int num_completed = 0; + }; + + TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op, + TensorShape shape, float value_offset) { + TestCase* test_case = new TestCase(); + test_case->expected = Tensor(DT_FLOAT, shape); + if (reduction_op == ncclProd) { + test::FillFn<float>(&test_case->expected, [](int) { return 1; }); + } else if (reduction_op == ncclSum) { + test::FillFn<float>(&test_case->expected, [](int) { return 0; }); + } else if (reduction_op == ncclMax) { + test::FillFn<float>(&test_case->expected, [](int) { + return -1 * std::numeric_limits<float>::max(); + }); + } else if (reduction_op == ncclMin) { + test::FillFn<float>(&test_case->expected, [](int) { + return std::numeric_limits<float>::max(); + }); + } else { + LOG(FATAL) << "Invalid reduction_op " << reduction_op; + } + + int mult = 1; + for (int i = 0; i < num_ranks; ++i) { + auto* device = devices->at(i % devices->size()); + auto* stream = device->tensorflow_gpu_device_info()->stream; + + Tensor in_cpu(DT_FLOAT, shape); + test::FillFn<float>(&in_cpu, [mult, value_offset](int index) { + return value_offset + (index + 1) * mult; + }); + for (int j = 0; j < shape.num_elements(); ++j) { + auto in_val = in_cpu.flat<float>()(j); + auto out_expr = test_case->expected.flat<float>(); + if (reduction_op == ncclProd) { + out_expr(j) *= in_val; + } else if (reduction_op == ncclSum) { + out_expr(j) += in_val; + } else if (reduction_op == ncclMax) { + if (in_val > out_expr(j)) { + out_expr(j) = in_val; + } + } else if (reduction_op == ncclMin) { + if (in_val < out_expr(j)) { + out_expr(j) = in_val; + } + } + } + + mult *= 10; + test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape); + test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape); + + const Tensor& in_gpu = test_case->ins.back(); + auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data()); + stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(), + in_cpu.TotalBytes()); + } + return test_case; + } + + NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) { + return [this, test_case](Status s) { + mutex_lock l(test_case->mu); + ++test_case->num_completed; + test_case->final_status.Update(s); + }; + } + + void VerifyResults(const string& case_label, TestCase* test_case) { + // Wait for the done callback to be called. + { + test_case->mu.lock(); + while (test_case->num_completed != test_case->outs.size()) { + test_case->mu.unlock(); + Env::Default()->SleepForMicroseconds(10); + test_case->mu.lock(); + } + test_case->mu.unlock(); + } + // Copy memory to host and verify. + for (int i = 0; i < test_case->outs.size(); ++i) { + auto* device = devices->at(i % devices->size()); + auto* stream = device->tensorflow_gpu_device_info()->stream; + const Tensor& out_gpu = test_case->outs[i]; + Tensor out_cpu(DT_FLOAT, out_gpu.shape()); + auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data()); + stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem, + out_cpu.TotalBytes()); + stream->BlockHostUntilDone(); + test::ExpectTensorEqual<float>(test_case->expected, out_cpu); + } + } +}; +std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr; + +// Test basic sum reduction. 
+TEST_F(NcclManagerTest, BasicSumReduction) { + const int num_ranks = 3; + + for (int op = 0; op < 4; ++op) { + ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op); + std::unique_ptr<TestCase> test_case( + MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0)); + for (int device_num = 0; device_num < num_ranks; ++device_num) { + auto* device = devices->at(device_num % devices->size()); + auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; + auto* stream = device->tensorflow_gpu_device_info()->stream; + NcclManager::instance()->AddToAllReduce( + num_ranks, "allreduce", reduction_op, device->executor(), event_mgr, + stream, &test_case->ins[device_num], &test_case->outs[device_num], + CreateDoneCallback(test_case.get())); + } + + LOG(ERROR) << "Verifying results"; + VerifyResults("test_case", test_case.get()); + } +} + +// Same as the Basic test, but with multiple threads launching parts of many +// reductions. +// +// Testing the multi-rank execution is currently reduced as it can hang when run +// with num_ranks > devices->size(), for some GPUs (e.g. K20m). +// To test the higher settings, increase num_ranks, +// num_collectives_per_iteration and time_limit_micros. +TEST_F(NcclManagerTest, MultipleCallers) { + const int num_ranks = 1; // 2; + const int num_collectives_per_iteration = 1; // 1000; + const int num_threads = 3; + const int time_limit_micros = 1; // 60 * 30 * 1000 * 1000; + + int64 start = Env::Default()->NowMicros(); + srand(Env::Default()->NowMicros()); + + for (;;) { + std::vector<std::pair<int, int>> case_and_device_num; + std::vector<std::unique_ptr<TestCase>> test_cases; + for (int i = 0; i < num_collectives_per_iteration; ++i) { + test_cases.emplace_back( + MakeTestCase(num_ranks, ncclSum, + TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i)); + for (int j = 0; j < num_ranks; ++j) { + case_and_device_num.emplace_back(i, j); + } + } + + for (int i = 0; i < num_ranks; ++i) { + auto* device = devices->at(i % devices->size()); + auto* stream = device->tensorflow_gpu_device_info()->stream; + stream->BlockHostUntilDone(); + } + + std::random_shuffle(case_and_device_num.begin(), case_and_device_num.end()); + + mutex mu; // guards case_and_device_num. + std::unique_ptr<thread::ThreadPool> pool( + new thread::ThreadPool(Env::Default(), "test", num_threads)); + const int to_schedule = case_and_device_num.size(); + for (int i = 0; i < to_schedule; ++i) { + auto fn = [&]() { + int device_num; + int test_num; + { + mutex_lock l(mu); + test_num = case_and_device_num.back().first; + device_num = case_and_device_num.back().second; + case_and_device_num.pop_back(); + } + auto* device = devices->at(device_num % devices->size()); + auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; + auto* stream = device->tensorflow_gpu_device_info()->stream; + TestCase* test_case = test_cases[test_num].get(); + NcclManager::instance()->AddToAllReduce( + num_ranks, strings::StrCat("allreduce", test_num), ncclSum, + device->executor(), event_mgr, stream, &test_case->ins[device_num], + &test_case->outs[device_num], CreateDoneCallback(test_case)); + }; + pool->Schedule(fn); + } + pool.reset(); // wait for all work to be scheduled. 
+ + LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration + << " collectives"; + for (int i = 0; i < test_cases.size(); ++i) { + VerifyResults(strings::StrCat("collective", i), test_cases[i].get()); + } + + int64 delta = Env::Default()->NowMicros() - start; + if (delta > time_limit_micros) { + LOG(ERROR) << "Ran for " << delta << " quitting"; + break; + } + } +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc new file mode 100644 index 0000000000..db6ee3e0e7 --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -0,0 +1,157 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include <unordered_map> +#include <vector> + +#include "external/nccl_archive/src/nccl.h" +#include "tensorflow/contrib/nccl/kernels/nccl_manager.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// Base class for all communicator ops that use nccl. +// +// About memory management and stream syncing: +// 1. The nccl communicator has a stream for each rank. +// 2. For input tensors to the communicator, the compute stream is passed to the +// NcclManager which will do a needed +// communicator_stream.ThenWaitFor(input_tensor_stream). +// 3. The done_callback of the async kernel is not called by the +// NcclManager until after the communicator kernel is complete. This +// is enough to a) keep the input tensor data valid for the lifetime of the +// collective; and b) ensure the data in the output tensor is available +// when the async op kernel's done callback is called. +class NcclAsyncOpBase : public AsyncOpKernel { + public: + NcclAsyncOpBase(OpKernelConstruction* c) : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("num_devices", &num_devices_)); + OP_REQUIRES_OK(c, c->GetAttr("shared_name", &collective_prefix_)); + } + + string GetCollectiveKey(OpKernelContext* c) { + return strings::StrCat(collective_prefix_, ";", c->step_id(), ";", + c->frame_iter().frame_id, ":", + c->frame_iter().iter_id); + } + + int num_devices() const { return num_devices_; } + + private: + int num_devices_; + string collective_prefix_; + + TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase); +}; + +// To execute a single all-reduce, this kernel is called once for each of the +// <k> devices in the communicator. 
+class NcclAllReduceOpKernel : public NcclAsyncOpBase { + public: + NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) { + string reduction; + OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction)); + if (reduction == "min") { + reduction_op_ = ncclMin; + } else if (reduction == "max") { + reduction_op_ = ncclMax; + } else if (reduction == "sum") { + reduction_op_ = ncclSum; + } else if (reduction == "prod") { + reduction_op_ = ncclProd; + } else { + OP_REQUIRES_OK(c, + errors::InvalidArgument("Invalid reduction: ", reduction)); + } + } + + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + const Tensor* in_t = &c->input(0); + Tensor* out_t; + OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t->shape(), &out_t), done); + + auto actual_done = [c, done](Status s) { + OP_REQUIRES_OK_ASYNC(c, s, done); + done(); + }; + + auto* compute_stream = c->op_device_context()->stream(); + EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr; + NcclManager::instance()->AddToAllReduce( + num_devices(), GetCollectiveKey(c), reduction_op_, + compute_stream->parent(), event_mgr, compute_stream, in_t, out_t, + actual_done); + } + + private: + ncclRedOp_t reduction_op_; +}; + +REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU), + NcclAllReduceOpKernel); + +class NcclBroadcastSendKernel : public NcclAsyncOpBase { + public: + NcclBroadcastSendKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {} + + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + auto actual_done = [c, done](Status s) { + OP_REQUIRES_OK_ASYNC(c, s, done); + done(); + }; + + auto* compute_stream = c->op_device_context()->stream(); + EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr; + NcclManager::instance()->AddBroadcastSend( + num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr, + compute_stream, &c->input(0), std::move(actual_done)); + } +}; +REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU), + NcclBroadcastSendKernel); + +class NcclBroadcastRecvKernel : public NcclAsyncOpBase { + public: + NcclBroadcastRecvKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {} + + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + const Tensor& shape_t = c->input(0); + TensorShape shape; + OP_REQUIRES_OK_ASYNC( + c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done); + Tensor* out_t; + OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done); + + auto actual_done = [c, done](Status s) { + OP_REQUIRES_OK_ASYNC(c, s, done); + done(); + }; + + auto* compute_stream = c->op_device_context()->stream(); + EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr; + NcclManager::instance()->AddBroadcastRecv( + num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr, + compute_stream, out_t, std::move(actual_done)); + } +}; +REGISTER_KERNEL_BUILDER( + Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), + NcclBroadcastRecvKernel); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc new file mode 100644 index 0000000000..d767636fad --- /dev/null +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -0,0 +1,94 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("NcclAllReduce") + .Input("input: T") + .Output("data: T") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .Attr("shared_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Outputs a tensor containing the reduction across all input tensors passed to ops +within the same `shared_name. + +The graph should be constructed so if one op runs with shared_name value `c`, +then `num_devices` ops will run with shared_name value `c`. Failure to do so +will cause the graph execution to fail to complete. + +input: the input to the reduction +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. +num_devices: The number of devices participating in this reduction. +shared_name: Identifier that shared between ops of the same reduction. +)doc"); + +REGISTER_OP("NcclBroadcastSend") + .Input("input: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .Attr("shared_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`. + +The graph should be constructed so that one device runs `NcclBroadcastSend` and +`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Failure to do so will cause the graph execution to fail to complete. + +input: The input to the broadcast +num_devices: The number of devices participating in this reduction. +shared_name: Identifier that is shared between ops of the same broadcast. + )doc"); + +REGISTER_OP("NcclBroadcastRecv") + .Input("shape: int64") + .Output("output: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .Attr("shared_name: string") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Sends data of shape `shape` from the NcclBroadcastSend op registered in the +same `shared_name`. + +The graph should be constructed so that one device runs `NcclBroadcastSend` and +`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Failure to do so will cause the graph execution to fail to complete. + +shape: The shape of the output. +output: The broadcast data received from the NcclBroadcastSend op. +num_devices: The number of devices participating in this reduction. +shared_name: Identifier that is shared between ops of the same broadcast. 
+ )doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py new file mode 100644 index 0000000000..b31cc53e0a --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -0,0 +1,168 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for GPU collective operations implemented using NVIDIA nccl.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.contrib.nccl.ops import gen_nccl_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import device +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import resource_loader + +_nccl_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile('_nccl_ops.so')) + + +def all_sum(tensors): + """Returns a list of tensors with the all-reduce sum across `tensors`. + + The computation is done with an all-reduce operation, so if only some of the + returned tensors are evaluated then the computation will hang. + + Args: + tensors: The input tensors across which to sum; must be assigned + to GPU devices. + + Returns: + List of tensors, each with the sum of the input tensors, where tensor i has + the same device as `tensors[i]`. + """ + return _apply_all_reduce('sum', tensors) + + +def all_prod(tensors): + """Returns a list of tensors with the all-reduce product across `tensors`. + + The computation is done with an all-reduce operation, so if only some of the + returned tensors are evaluated then the computation will hang. + + Args: + tensors: The input tensors across which to multiply; must be assigned + to GPU devices. + + Returns: + List of tensors, each with the product of the input tensors, where tensor i + has the same device as `tensors[i]`. + """ + return _apply_all_reduce('prod', tensors) + + +def all_min(tensors): + """Returns a list of tensors with the all-reduce min across `tensors`. + + The computation is done with an all-reduce operation, so if only some of the + returned tensors are evaluated then the computation will hang. + + Args: + tensors: The input tensors across which to reduce; must be assigned + to GPU devices. + + Returns: + List of tensors, each with the minimum of the input tensors, where tensor i + has the same device as `tensors[i]`. + """ + return _apply_all_reduce('min', tensors) + + +def all_max(tensors): + """Returns a list of tensors with the all-reduce max across `tensors`. + + The computation is done with an all-reduce operation, so if only some of the + returned tensors are evaluated then the computation will hang. + + Args: + tensors: The input tensors across which to reduce; must be assigned + to GPU devices. 
+ + Returns: + List of tensors, each with the maximum of the input tensors, where tensor i + has the same device as `tensors[i]`. + """ + return _apply_all_reduce('max', tensors) + + +def broadcast(src_tensor, dst_devices): + """Returns a list of tensors on `dst_devices`, each with value `tensor`. + + The computation is done with a broadcast nccl operation, so if only some of + the returned tensors and src_tensor are evaluated then the computation will + hang. + + Args: + src_tensor: The tensor to send; must be assigned to a GPU device. + dst_devices: The GPU devices to receive the sent tensor. + + Returns: + List of tensors, each with the value of `src_tensor`, which the device + of tensor i is `dst_devices[i]`. + """ + if not dst_devices: + raise ValueError('Must pass >0 dst_devices to broadcast') + all_devices = [src_tensor.device] + dst_devices + shared_name = _get_shared_name() + + with ops.device(src_tensor.device): + send = gen_nccl_ops.nccl_broadcast_send( + input=src_tensor, num_devices=len(all_devices), shared_name=shared_name) + + shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64) + recvs = [] + for d in dst_devices: + with ops.device(d): + recvs.append( + gen_nccl_ops.nccl_broadcast_recv( + shape=shape_op, + T=src_tensor.dtype, + num_devices=len(all_devices), + shared_name=shared_name)) + + return send, recvs + + +def _apply_all_reduce(reduction_op, tensors): + if not tensors: + raise ValueError('Must pass >0 tensors to all reduce operations') + shared_name = _get_shared_name() + res = [] + for t in tensors: + if not device.canonical_name(t.device): + raise ValueError('Device assignment required for nccl collective ops') + with ops.device(t.device): + res.append( + gen_nccl_ops.nccl_all_reduce( + t, + reduction=reduction_op, + num_devices=len(tensors), + shared_name=shared_name)) + return res + + +_lock = threading.Lock() +_shared_name_counter = 0 + + +def _get_shared_name(): + global _shared_name_counter + + with _lock: + val = _shared_name_counter + _shared_name_counter += 1 + return 'c%s' % val diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py new file mode 100644 index 0000000000..130cb4ca12 --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -0,0 +1,151 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for nccl ops. 
See also the cc test for nccl_communicator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import nccl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class AllReduceTest(test.TestCase): + + def testAllReduce(self): + if not test.is_gpu_available(): + return # Test requires access to a GPU + + for dtype in [np.float32, np.int32, np.int64, np.float64]: + # Create session inside outer loop to test use of + # same communicator across multiple sessions. + with self.test_session(use_gpu=True) as sess: + self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y) + self._testSingleAllReduce(sess, dtype, nccl.all_prod, + lambda x, y: x * y) + self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum) + self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum) + + def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn): + for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: + shape = (3, 4) + np_ans = None + tensors = [] + for d in devices: + with ops.device(d): + t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type) + if np_ans is None: + np_ans = t + else: + np_ans = numpy_accumulation_fn(np_ans, t) + tensors.append(array_ops.identity(t)) + + all_reduce_tensors = nccl_fn(tensors) + + # Test shape inference. + for r in all_reduce_tensors: + self.assertEqual(shape, r.get_shape()) + + # Test execution and results. + nccl_results = sess.run(all_reduce_tensors) + for r in nccl_results: + self.assertAllClose(r, np_ans) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, 'Device assignment required'): + nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))]) + with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'): + nccl.all_sum([]) + + +class BroadcastTest(test.TestCase): + + def testBroadcast(self): + if not test.is_gpu_available(): + return # Test requires access to a GPU + + for dtype in [np.float32, np.int32, np.int64, np.float64]: + # Create session inside outer loop to test use of + # same communicator across multiple sessions. + with self.test_session(use_gpu=True) as sess: + for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: + shape = (3, 4) + sender = np.random.randint(0, len(devices) - 1) + with ops.device(devices[sender]): + np_ans = (( + (np.random.random_sample(shape) - .5) * 1024).astype(dtype)) + t = array_ops.identity(np_ans) + other_devices = devices[:sender] + devices[sender + 1:] + send_op, received_tensors = nccl.broadcast(t, other_devices) + + # Verify shape inference. + for r in received_tensors: + self.assertEqual(shape, r.get_shape()) + + # Run and verify results. + nccl_results = sess.run(received_tensors + [send_op]) + for r in nccl_results[:-1]: + self.assertAllClose(r, np_ans) + + +class CombinedTest(test.TestCase): + """Tests using a mix of all-reduce ops in one session.run call.""" + + def testCombined(self): + if not test.is_gpu_available(): + return # Test requires access to a GPU + + for dtype in [np.float32, np.int32, np.int64, np.float64]: + # Create session inside outer loop to test use of + # same communicator across multiple sessions. 
+ with self.test_session(use_gpu=True) as sess: + for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]: + shape = (3, 4) + + # all-reduce + np_ans = np.zeros(shape=shape, dtype=dtype) + tensors = [] + for d in devices: + with ops.device(d): + t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype) + np_ans += t + tensors.append(array_ops.identity(t)) + all_reduce_tensors = nccl.all_sum(tensors) + + sender = np.random.randint(0, len(devices) - 1) + other_devices = devices[:sender] + devices[sender + 1:] + send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender], + other_devices) + + # sender doesn't need to be fetched as part of outputs of session.run. + del all_reduce_tensors[sender] + + # Verify shape inference. + for r in received_tensors: + self.assertEqual(shape, r.get_shape()) + + # Run and verify results. + nccl_results = sess.run( + received_tensors + [send_op] + all_reduce_tensors) + for r in nccl_results[:len(received_tensors)]: + self.assertAllClose(r, np_ans) + + +if __name__ == '__main__': + test.main() |
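
For the broadcast path, the new broadcast() wrapper returns the send op together with one received tensor per destination device; as the test above shows, the send op must be fetched together with the receives. A small sketch, again assuming a second GPU is present (device names are illustrative):

```python
import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/gpu:0'):
  src = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# One NcclBroadcastSend on the source device, one NcclBroadcastRecv per
# destination device; all of them share a generated shared_name.
send_op, received = nccl.broadcast(src, ['/gpu:1'])

with tf.Session() as sess:
  outs = sess.run(received + [send_op])  # fetch send and receives together
  print(outs[0])                         # value of src, received on /gpu:1
```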