author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-01-24 15:19:38 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-01-24 15:29:03 -0800
commit    5be95cbb389bc112161232c8514155947063ea72 (patch)
tree      f0812bf2efeb798155000d6133f9e58fd3738c86 /tensorflow/contrib/nccl
parent    761405e7202e1bec875f1ca7d1a7660ebbb3dafb (diff)
Add contrib/nccl for using all-reduce collectives across GPUs of a single
server. Change: 145475050
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r--  tensorflow/contrib/nccl/BUILD                          120
-rw-r--r--  tensorflow/contrib/nccl/__init__.py                     24
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.cc        471
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.h         122
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager_test.cc   285
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_ops.cc            157
-rw-r--r--  tensorflow/contrib/nccl/ops/nccl_ops.cc                 94
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops.py         168
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops_test.py    151
9 files changed, 1592 insertions, 0 deletions
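
For context, a minimal sketch of how the Python API introduced by this change is used, mirroring the pattern in nccl_ops_test.py below. It assumes a CUDA-enabled build with at least two visible GPUs; the device names are illustrative only.

import numpy as np
import tensorflow as tf
from tensorflow.contrib import nccl

# One input tensor per GPU; each all-reduce output is placed on the same
# device as its corresponding input.
tensors = []
for d in ['/gpu:0', '/gpu:1']:
  with tf.device(d):
    tensors.append(tf.constant(np.random.rand(3, 4).astype(np.float32)))

summed = nccl.all_sum(tensors)

with tf.Session() as sess:
  # All outputs must be fetched in the same run call, or the collective hangs.
  results = sess.run(summed)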
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
new file mode 100644
index 0000000000..7c352dae88
--- /dev/null
+++ b/tensorflow/contrib/nccl/BUILD
@@ -0,0 +1,120 @@
+# Description:
+# Wrap NVIDIA (https://github.com/NVIDIA/nccl) NCCL with tensorflow ops.
+# APIs are meant to change over time.
+package(
+ default_visibility = ["//visibility:private"],
+ features = ["-parse_headers"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cuda_cc_test",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+
+tf_custom_op_library(
+ name = "python/ops/_nccl_ops.so",
+ srcs = [
+ "kernels/nccl_manager.cc",
+ "kernels/nccl_manager.h",
+ "kernels/nccl_ops.cc",
+ "ops/nccl_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core:gpu_headers_lib",
+ "@nccl_archive//:nccl",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["nccl_ops"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "nccl_ops",
+ deps = [":nccl_ops_op_lib"],
+)
+
+py_library(
+ name = "nccl_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/nccl_ops.py",
+ ],
+ data = [
+ ":python/ops/_nccl_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":nccl_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+cuda_py_test(
+ name = "nccl_ops_test",
+ size = "small",
+ srcs = ["python/ops/nccl_ops_test.py"],
+ additional_deps = [
+ ":nccl_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "requires_cudnn5",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "nccl_manager_test",
+ size = "small",
+ srcs = if_cuda(
+ [
+ "kernels/nccl_manager.cc",
+ "kernels/nccl_manager.h",
+ "kernels/nccl_manager_test.cc",
+ ],
+ [],
+ ),
+ deps = if_cuda(
+ [
+ "@nccl_archive//:nccl",
+ "//tensorflow/core",
+ "//tensorflow/core:cuda",
+ ],
+ [],
+ ) + [
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py
new file mode 100644
index 0000000000..0275ed6079
--- /dev/null
+++ b/tensorflow/contrib/nccl/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for nccl AllReduce."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.nccl.python.ops.nccl_ops import *
+# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
new file mode 100644
index 0000000000..31e85b571d
--- /dev/null
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -0,0 +1,471 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
+
+#ifdef GOOGLE_CUDA
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cuda.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+
+// Contains data for a single stream used for nccl communication; this includes
+// a background thread that calls NcclManager::LoopKernelLaunches.
+struct NcclManager::NcclStream {
+ public:
+ NcclStream() {}
+ ~NcclStream() {
+ mutex_lock l(mu);
+ shutdown_requested = true;
+ cv.notify_all();
+ }
+
+ perftools::gputools::StreamExecutor* executor = nullptr;
+
+ // The stream on which to run the nccl collective.
+ // This is a different stream than the tensorflow compute stream.
+ std::unique_ptr<perftools::gputools::Stream> stream;
+
+ // See NcclManager::LoopKernelLaunches for information on these.
+ std::unique_ptr<Thread> thread;
+ mutex mu;
+ condition_variable cv;
+ // Has collective,rank pairs.
+ std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
+ bool shutdown_requested GUARDED_BY(mu) = false;
+};
+
+struct NcclManager::CommunicatorMember {
+ public:
+ CommunicatorMember() {}
+ ~CommunicatorMember() {
+ if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
+ }
+ ncclComm_t nccl_comm;
+
+ // Owned by NcclManager::device_to_comm_streams_.
+ NcclStream* nccl_stream = nullptr;
+};
+
+struct NcclManager::Communicator {
+ public:
+ Communicator(std::vector<CommunicatorMember> members)
+ : num_devices(members.size()), members(std::move(members)) {}
+
+ const int num_devices;
+ const std::vector<CommunicatorMember> members; // indexed by rank.
+};
+
+namespace {
+ncclDataType_t ToNcclType(DataType t) {
+ switch (t) {
+ case DT_FLOAT:
+ return ncclFloat;
+ case DT_DOUBLE:
+ return ncclDouble;
+ case DT_INT32:
+ return ncclInt;
+ case DT_INT64:
+ return ncclInt64;
+ default:
+ return ncclFloat;
+ }
+}
+} // namespace
+
+// A participant in a Collective. See <Collective> below.
+struct NcclManager::Participant {
+ Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ perftools::gputools::StreamExecutor* executor,
+ NcclManager::DoneCallback done_callback)
+ : in_t(in_t),
+ out_t(out_t),
+ event_mgr(event_mgr),
+ tensor_stream(tensor_stream),
+ executor(executor),
+ done_callback(std::move(done_callback)) {
+ DCHECK(executor != nullptr);
+ DCHECK(event_mgr != nullptr);
+ DCHECK(tensor_stream != nullptr);
+ }
+ // Owned by the caller, who must keep it live until <done_callback> is called.
+ // Is NULL for participants that only receive data.
+ const Tensor* in_t;
+
+ // Owned by the caller, who must keep it live until <done_callback> is called.
+ // Is NULL for participants that only send data.
+ Tensor* out_t;
+
+ // Owned by the caller, who must keep it live until <done_callback> is called.
+ EventMgr* const event_mgr;
+
+ // Owned by the caller, who must keep it live until <done_callback> is called.
+ perftools::gputools::Stream* const tensor_stream;
+
+ // Matches the executor in CommunicatorMember::stream. Expected to be live for
+ // process lifetime.
+ perftools::gputools::StreamExecutor* executor = nullptr;
+
+ NcclManager::DoneCallback done_callback;
+
+ bool root = false;
+};
+
+// A Collective tracks a single communicator operation (e.g., a single
+// AllReduce call).
+struct NcclManager::Collective {
+ Collective(DataType data_type_in, CollectiveType type_in,
+ ncclRedOp_t reduction_op_in, int num_devices)
+ : data_type(data_type_in),
+ type(type_in),
+ reduction_op(reduction_op_in),
+ remaining_participants(num_devices) {
+ participants.reserve(num_devices);
+ }
+
+ const DataType data_type;
+ const CollectiveType type;
+ const ncclRedOp_t reduction_op; // applies when <type> is a reduction.
+
+ Communicator* communicator = nullptr;
+
+ // All collective participants.
+ //
+ // Adding values in this vector is guarded by the mutex of the containing
+ // NcclManager.
+ std::vector<std::unique_ptr<Participant>> participants;
+
+ // For collective types that have a root (e.g. the root of broadcast is the
+ // sender), this is the rank of the root.
+ int root_rank = -1;
+
+ // How many participants have been registered so far. The Collective is
+ // eligible to run once <available_participants> equals the number of
+ // devices in the collective.
+ //
+ // Guarded by the mutex of the containing NcclManager.
+ int available_participants = 0;
+
+ mutable std::atomic_int_fast32_t remaining_participants;
+};
+
+NcclManager::NcclManager() {}
+NcclManager::~NcclManager() {}
+NcclManager* NcclManager::instance() {
+ static NcclManager* instance = new NcclManager();
+ return instance;
+}
+
+NcclManager::Communicator* NcclManager::GetCommunicator(
+ NcclManager::Collective* collective) {
+ // Sort by executor to make ordering of executors deterministic.
+ std::sort(collective->participants.begin(), collective->participants.end(),
+ [](const std::unique_ptr<Participant>& a,
+ const std::unique_ptr<Participant>& b) {
+ return a->executor < b->executor;
+ });
+ const int num_devices = collective->participants.size();
+
+ mutex_lock l(mu_);
+
+ // Scan to find an existing communicator that provides nccl communication
+ // between the executors used by the participants in the collective. For
+ // example, if a collective is for GPUs 0, 1, and 2 then this will scan
+ // to find the communicator for GPUs 0, 1, and 2.
+ //
+ // Note that each executor identifies a context on one device, so this is the
+ // same as getting the communicator connecting the devices in the collective.
+ // A device can be in different communicators as well - for example, a
+ // communicator for GPUs 0 and 1 is separate from one for GPUs 0, 1, and 2.
+ //
+ // Since it's expected that a small number of distinct communicators will
+ // be needed, communicators_ is not garbage collected currently.
+ //
+ // Launching of kernels must be serialized so that, given collectives A and B,
+ // and an order of them (e.g., A before B), then for each comm_stream
+ // involved, the kernel for A is launched before the kernel for B. This is
+ // currently guaranteed by a global mutex controlling additions of the kernels
+ // to per-stream launch queues. The launch queues are processed by
+ // LoopKernelLaunches.
+ for (auto& comm : communicators_) {
+ if (comm->num_devices == num_devices) {
+ int i;
+ for (i = 0; i < num_devices; ++i) {
+ if (comm->members[i].nccl_stream->executor !=
+ collective->participants[i]->executor) {
+ break;
+ }
+ }
+ if (i == num_devices) return comm.get();
+ }
+ }
+
+ auto* env = Env::Default();
+ std::set<NcclStream*> used_streams;
+
+ // Create and initialize a new communicator.
+ // Note that this is done under the lock; performance is not expected to
+ // matter as this happens a very small number of times.
+ std::vector<CommunicatorMember> members(num_devices);
+ for (int i = 0; i < num_devices; ++i) {
+ auto* executor = collective->participants[i]->executor;
+
+ // Find a communication stream to use for the device.
+ auto& streams = device_to_comm_streams_[executor];
+ NcclStream* nccl_stream = nullptr;
+ for (const auto& s : streams) {
+ if (used_streams.insert(s.get()).second) {
+ nccl_stream = s.get();
+ break;
+ }
+ }
+ if (nccl_stream == nullptr) {
+ nccl_stream = new NcclStream();
+ nccl_stream->executor = executor;
+ nccl_stream->stream.reset(new perftools::gputools::Stream(executor));
+ nccl_stream->stream->Init();
+
+ streams.emplace_back(nccl_stream);
+ used_streams.insert(nccl_stream);
+
+ nccl_stream->thread.reset(env->StartThread(
+ ThreadOptions(), "nccl_kernel_launch",
+ [this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
+ }
+
+ members[i].nccl_stream = nccl_stream;
+ }
+
+ // Call ncclCommInitRank for each member.
+ ncclUniqueId id;
+ CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
+ std::unique_ptr<thread::ThreadPool> pool(
+ new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
+ std::vector<ncclResult_t> results(num_devices);
+ for (int rank = 0; rank < num_devices; ++rank) {
+ CommunicatorMember* member = &members[rank];
+ ncclResult_t* result = &results[rank];
+ pool->Schedule([member, num_devices, result, rank, &id]() {
+ ScopedActivateExecutorContext scoped_context(
+ member->nccl_stream->executor);
+ LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
+ *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
+ LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : "
+ << *result;
+ });
+ }
+
+ pool.reset(); // wait for completion.
+ for (int i = 0; i < num_devices; ++i) {
+ CHECK_EQ(results[i], ncclSuccess);
+ }
+ communicators_.emplace_back(new Communicator(std::move(members)));
+ return communicators_.back().get();
+}
+
+void NcclManager::AddToAllReduce(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* out_t,
+ const DoneCallback& done_callback) {
+ std::unique_ptr<Participant> participant(new Participant(
+ in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
+ AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+ kAllReduce, reduction_op);
+}
+
+void NcclManager::AddBroadcastSend(
+ int num_devices, const string& key,
+ perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
+ DoneCallback done_callback) {
+ std::unique_ptr<Participant> participant(
+ new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
+ executor, done_callback));
+ participant->root = true;
+ AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
+ kBroadcast, ncclSum /* unused */);
+}
+
+void NcclManager::AddBroadcastRecv(
+ int num_devices, const string& key,
+ perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream, Tensor* out_t,
+ DoneCallback done_callback) {
+ std::unique_ptr<Participant> participant(
+ new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
+ executor, done_callback));
+ AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
+ kBroadcast, ncclSum /* unused */);
+}
+
+void NcclManager::AddParticipant(int num_devices, const string& key,
+ std::unique_ptr<Participant> participant,
+ DataType data_type,
+ CollectiveType collective_type,
+ ncclRedOp_t reduction_op) {
+ Collective* to_run = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto& collective_ptr = collectives_[key];
+ if (collective_ptr == nullptr) {
+ collective_ptr.reset(new Collective(data_type, collective_type,
+ reduction_op, num_devices));
+ }
+ Collective* collective = collective_ptr.get();
+ DCHECK_EQ(collective->type, collective_type);
+ DCHECK_LT(collective->participants.size(), num_devices);
+ collective->participants.emplace_back(std::move(participant));
+ ++collective->available_participants;
+
+ if (collective->available_participants == num_devices) {
+ to_run = collective;
+
+ // Ownership is going to be transferred to RunCollective.
+ collective_ptr.release();
+ collectives_.erase(key);
+ }
+ }
+
+ if (to_run != nullptr) {
+ RunCollective(key, to_run);
+ }
+}
+
+void NcclManager::RunCollective(const string& key, Collective* collective) {
+ static mutex collective_mu;
+
+ auto* communicator = GetCommunicator(collective);
+ collective->communicator = communicator;
+ const int size = communicator->num_devices;
+
+ for (int rank = 0; rank < size; ++rank) {
+ Participant* p = collective->participants[rank].get();
+ NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
+ CHECK(nccl_stream != nullptr);
+
+ if (p->in_t != nullptr) {
+ // Wait to ensure that the kernel that produces the data in the input
+ // tensor has finished running before the nccl kernel runs on the
+ // communication stream.
+ nccl_stream->stream->ThenWaitFor(p->tensor_stream);
+ }
+ if (p->root) {
+ CHECK_EQ(collective->root_rank, -1);
+ collective->root_rank = rank;
+ }
+ }
+
+ if (collective->type == kBroadcast) {
+ CHECK_NE(collective->root_rank, -1);
+ }
+
+ {
+ // Allow only one collective at a time to queue kernels for launching. This
+ // is to prevent collectives from deadlocking each other.
+ // Note that it would be possible to run multiple collectives at once, if
+ // they have non-intersecting sets of devices.
+ mutex_lock l(collective_mu);
+ for (int rank = 0; rank < size; ++rank) {
+ NcclStream* nccl_stream = communicator->members[rank].nccl_stream;
+ mutex_lock l(nccl_stream->mu);
+ nccl_stream->pending_launches_.push_front(
+ std::make_pair(collective, rank));
+ nccl_stream->cv.notify_all();
+ }
+ }
+}
+
+void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
+ perftools::gputools::Stream* comm_stream = nccl_stream->stream.get();
+ ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
+ const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
+ comm_stream->implementation()->CudaStreamMemberHack());
+
+ while (true) {
+ // Find collective to run.
+ std::pair<Collective*, int> next_launch;
+ {
+ mutex_lock l(nccl_stream->mu);
+ while (nccl_stream->pending_launches_.empty()) {
+ if (nccl_stream->shutdown_requested) {
+ // No work and shutdown requested, exit.
+ return;
+ }
+ nccl_stream->cv.wait(l);
+ }
+ next_launch = nccl_stream->pending_launches_.back();
+ nccl_stream->pending_launches_.pop_back();
+ }
+ Collective* collective = next_launch.first;
+ int rank = next_launch.second;
+
+ // Launch the nccl kernel.
+ ncclDataType_t data_type = ToNcclType(collective->data_type);
+ Participant* p = collective->participants[rank].get();
+
+ auto nccl_comm = collective->communicator->members[rank].nccl_comm;
+ ncclResult_t nccl_result = ncclSuccess;
+ switch (collective->type) {
+ case kAllReduce: {
+ const void* sendbuff = p->in_t->tensor_data().data();
+ void* recvbuff = const_cast<char*>(p->out_t->tensor_data().data());
+
+ nccl_result =
+ ncclAllReduce(sendbuff, recvbuff, p->in_t->NumElements(), data_type,
+ collective->reduction_op, nccl_comm, *cu_stream);
+ break;
+ }
+ case kBroadcast: {
+ const Tensor* buf_t = p->in_t ? p->in_t : p->out_t;
+ void* buf = const_cast<char*>(buf_t->tensor_data().data());
+ nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
+ collective->root_rank, nccl_comm, *cu_stream);
+ break;
+ }
+ }
+
+ // Run the done_callback when the nccl kernel finishes running.
+ auto done_callback = [collective, rank, nccl_result]() {
+ if (nccl_result == ncclSuccess) {
+ collective->participants[rank]->done_callback(Status::OK());
+ } else {
+ // Propagate the error, but note that if other members of the collective
+ // did launch their kernels, then they are hanging.
+ collective->participants[rank]->done_callback(errors::Unknown(
+ "Error invoking AllReduce: ", ncclGetErrorString(nccl_result)));
+ }
+
+ // TODO(cwhipkey): use RefCounted after figuring out how to use in a
+ // custom op library.
+ // See tensorflow/core/lib/core/refcount.h for details on this locking.
+ if (collective->remaining_participants.load(std::memory_order_acquire) ==
+ 1 ||
+ collective->remaining_participants.fetch_sub(1) == 1) {
+ delete collective;
+ }
+ };
+ p->event_mgr->ThenExecute(comm_stream, done_callback);
+ }
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
new file mode 100644
index 0000000000..8d5e5ddf76
--- /dev/null
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -0,0 +1,122 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+
+#ifdef GOOGLE_CUDA
+
+#include <unordered_map>
+#include <vector>
+
+#include "external/nccl_archive/src/nccl.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+
+// The NcclManager is used to make the asynchronous communicator calls and to
+// manage the per-device streams used for communication.
+//
+// See nccl_ops.cc for example usage, including description of memory
+// management and stream synchronization.
+class NcclManager {
+ public:
+ typedef std::function<void(Status)> DoneCallback;
+ NcclManager();
+ ~NcclManager();
+
+ static NcclManager* instance();
+
+ // Add one participant to an all-reduce, sending in data from <in_t> and
+ // receiving the result of the all-reduce in <out_t>. The device for this
+ // participant is managed by <executor>, and its events are polled by
+ // <event_mgr>.
+ //
+ // This is an asynchronous call. When <done_callback> is called, <out_t> has
+ // been set to the all-reduce result (note: the stream may not yet have been
+ // synced).
+ //
+ // <tensor_stream> is the stream that should be waited on to ensure <in_t>'s
+ // data is available on the GPU for the communication stream to access. It
+ // is also the stream that will use the produced data; <done_callback> is
+ // not called until the next kernel launched on <tensor_stream> would see the data.
+ void AddToAllReduce(int num_devices, const string& key,
+ ncclRedOp_t reduction_op,
+ perftools::gputools::StreamExecutor* executor,
+ EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, Tensor* out_t,
+ const DoneCallback& done_callback);
+
+ // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
+ // to all receivers.
+ void AddBroadcastSend(int num_devices, const string& key,
+ perftools::gputools::StreamExecutor* executor,
+ EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ const Tensor* in_t, DoneCallback done_callback);
+ void AddBroadcastRecv(int num_devices, const string& key,
+ perftools::gputools::StreamExecutor* executor,
+ EventMgr* event_mgr,
+ perftools::gputools::Stream* tensor_stream,
+ Tensor* out_t, DoneCallback done_callback);
+
+ private:
+ enum CollectiveType {
+ kAllReduce = 1,
+ kBroadcast = 2,
+ };
+ struct Collective;
+ struct Communicator;
+ struct CommunicatorMember;
+ struct NcclStream;
+ struct Participant;
+
+ Communicator* GetCommunicator(Collective* collective);
+
+ void AddParticipant(int num_devices, const string& key,
+ std::unique_ptr<Participant> participant,
+ DataType data_type, CollectiveType collective_type,
+ ncclRedOp_t reduction_op);
+
+ // Run <collective>. This call takes ownership of <collective>.
+ void RunCollective(const string& key, Collective* collective);
+ void LoopKernelLaunches(NcclStream* stream);
+
+ mutex mu_;
+
+ // Maps key to collectives currently being assembled or run.
+ std::unordered_map<string, std::unique_ptr<Collective>> collectives_
+ GUARDED_BY(mu_);
+
+ // Maps a device to the communication streams that make up its collective.
+ // This is used to share the stream across different communicators that
+ // include the same device.
+ std::map<perftools::gputools::StreamExecutor*,
+ std::vector<std::unique_ptr<NcclStream>>>
+ device_to_comm_streams_ GUARDED_BY(mu_);
+
+ std::vector<std::unique_ptr<Communicator>> communicators_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
+};
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
new file mode 100644
index 0000000000..b53cb82440
--- /dev/null
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -0,0 +1,285 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef GOOGLE_CUDA
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+static std::vector<BaseGPUDevice*> GetGPUDevices() {
+ std::vector<Device*> devices;
+ SessionOptions session_options;
+ session_options.env = Env::Default();
+ Status s = DeviceFactory::GetFactory(DEVICE_GPU)
+ ->AddDevices(session_options, "", &devices);
+ TF_CHECK_OK(s);
+ std::vector<BaseGPUDevice*> gpus;
+ for (Device* d : devices) {
+ if (d->device_type() == "GPU") {
+ gpus.push_back(static_cast<BaseGPUDevice*>(d));
+ } else {
+ delete d;
+ }
+ }
+ return gpus;
+}
+
+class NcclManagerTest : public ::testing::Test {
+ protected:
+ static void SetUpTestCase() {
+ setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
+ devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
+ CHECK(!devices->empty());
+ LOG(ERROR) << "Running test with " << devices->size() << " gpus";
+ }
+ static void TearDownTestCase() {
+ for (auto device : *devices) delete device;
+ delete devices;
+ }
+
+ static Allocator* gpu_allocator(BaseGPUDevice* device) {
+ return device->GetStepAllocator(AllocatorAttributes(),
+ nullptr /* step_resource_manager */);
+ }
+
+ static std::vector<BaseGPUDevice*>* devices;
+
+ template <typename Scalar>
+ perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
+ const Scalar* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(
+ const_cast<Scalar*>(cuda_memory));
+ perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+ return typed;
+ }
+
+ // A single all-reduce to apply.
+ struct TestCase {
+ string key;
+ std::vector<Tensor> ins;
+ std::vector<Tensor> outs;
+ Tensor expected;
+
+ mutex mu;
+ Status final_status;
+ int num_completed = 0;
+ };
+
+ TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
+ TensorShape shape, float value_offset) {
+ TestCase* test_case = new TestCase();
+ test_case->expected = Tensor(DT_FLOAT, shape);
+ if (reduction_op == ncclProd) {
+ test::FillFn<float>(&test_case->expected, [](int) { return 1; });
+ } else if (reduction_op == ncclSum) {
+ test::FillFn<float>(&test_case->expected, [](int) { return 0; });
+ } else if (reduction_op == ncclMax) {
+ test::FillFn<float>(&test_case->expected, [](int) {
+ return -1 * std::numeric_limits<float>::max();
+ });
+ } else if (reduction_op == ncclMin) {
+ test::FillFn<float>(&test_case->expected, [](int) {
+ return std::numeric_limits<float>::max();
+ });
+ } else {
+ LOG(FATAL) << "Invalid reduction_op " << reduction_op;
+ }
+
+ int mult = 1;
+ for (int i = 0; i < num_ranks; ++i) {
+ auto* device = devices->at(i % devices->size());
+ auto* stream = device->tensorflow_gpu_device_info()->stream;
+
+ Tensor in_cpu(DT_FLOAT, shape);
+ test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
+ return value_offset + (index + 1) * mult;
+ });
+ for (int j = 0; j < shape.num_elements(); ++j) {
+ auto in_val = in_cpu.flat<float>()(j);
+ auto out_expr = test_case->expected.flat<float>();
+ if (reduction_op == ncclProd) {
+ out_expr(j) *= in_val;
+ } else if (reduction_op == ncclSum) {
+ out_expr(j) += in_val;
+ } else if (reduction_op == ncclMax) {
+ if (in_val > out_expr(j)) {
+ out_expr(j) = in_val;
+ }
+ } else if (reduction_op == ncclMin) {
+ if (in_val < out_expr(j)) {
+ out_expr(j) = in_val;
+ }
+ }
+ }
+
+ mult *= 10;
+ test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
+ test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
+
+ const Tensor& in_gpu = test_case->ins.back();
+ auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
+ stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
+ in_cpu.TotalBytes());
+ }
+ return test_case;
+ }
+
+ NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
+ return [this, test_case](Status s) {
+ mutex_lock l(test_case->mu);
+ ++test_case->num_completed;
+ test_case->final_status.Update(s);
+ };
+ }
+
+ void VerifyResults(const string& case_label, TestCase* test_case) {
+ // Wait for the done callback to be called.
+ {
+ test_case->mu.lock();
+ while (test_case->num_completed != test_case->outs.size()) {
+ test_case->mu.unlock();
+ Env::Default()->SleepForMicroseconds(10);
+ test_case->mu.lock();
+ }
+ test_case->mu.unlock();
+ }
+ // Copy memory to host and verify.
+ for (int i = 0; i < test_case->outs.size(); ++i) {
+ auto* device = devices->at(i % devices->size());
+ auto* stream = device->tensorflow_gpu_device_info()->stream;
+ const Tensor& out_gpu = test_case->outs[i];
+ Tensor out_cpu(DT_FLOAT, out_gpu.shape());
+ auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
+ stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
+ out_cpu.TotalBytes());
+ stream->BlockHostUntilDone();
+ test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
+ }
+ }
+};
+std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;
+
+// Test basic sum reduction.
+TEST_F(NcclManagerTest, BasicSumReduction) {
+ const int num_ranks = 3;
+
+ for (int op = 0; op < 4; ++op) {
+ ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
+ std::unique_ptr<TestCase> test_case(
+ MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
+ for (int device_num = 0; device_num < num_ranks; ++device_num) {
+ auto* device = devices->at(device_num % devices->size());
+ auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+ auto* stream = device->tensorflow_gpu_device_info()->stream;
+ NcclManager::instance()->AddToAllReduce(
+ num_ranks, "allreduce", reduction_op, device->executor(), event_mgr,
+ stream, &test_case->ins[device_num], &test_case->outs[device_num],
+ CreateDoneCallback(test_case.get()));
+ }
+
+ LOG(ERROR) << "Verifying results";
+ VerifyResults("test_case", test_case.get());
+ }
+}
+
+// Same as the Basic test, but with multiple threads launching parts of many
+// reductions.
+//
+// Testing the multi-rank execution is currently reduced as it can hang when run
+// with num_ranks > devices->size(), for some GPUs (e.g. K20m).
+// To test the higher settings, increase num_ranks,
+// num_collectives_per_iteration and time_limit_micros.
+TEST_F(NcclManagerTest, MultipleCallers) {
+ const int num_ranks = 1; // 2;
+ const int num_collectives_per_iteration = 1; // 1000;
+ const int num_threads = 3;
+ const int time_limit_micros = 1; // 60 * 30 * 1000 * 1000;
+
+ int64 start = Env::Default()->NowMicros();
+ srand(Env::Default()->NowMicros());
+
+ for (;;) {
+ std::vector<std::pair<int, int>> case_and_device_num;
+ std::vector<std::unique_ptr<TestCase>> test_cases;
+ for (int i = 0; i < num_collectives_per_iteration; ++i) {
+ test_cases.emplace_back(
+ MakeTestCase(num_ranks, ncclSum,
+ TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
+ for (int j = 0; j < num_ranks; ++j) {
+ case_and_device_num.emplace_back(i, j);
+ }
+ }
+
+ for (int i = 0; i < num_ranks; ++i) {
+ auto* device = devices->at(i % devices->size());
+ auto* stream = device->tensorflow_gpu_device_info()->stream;
+ stream->BlockHostUntilDone();
+ }
+
+ std::random_shuffle(case_and_device_num.begin(), case_and_device_num.end());
+
+ mutex mu; // guards case_and_device_num.
+ std::unique_ptr<thread::ThreadPool> pool(
+ new thread::ThreadPool(Env::Default(), "test", num_threads));
+ const int to_schedule = case_and_device_num.size();
+ for (int i = 0; i < to_schedule; ++i) {
+ auto fn = [&]() {
+ int device_num;
+ int test_num;
+ {
+ mutex_lock l(mu);
+ test_num = case_and_device_num.back().first;
+ device_num = case_and_device_num.back().second;
+ case_and_device_num.pop_back();
+ }
+ auto* device = devices->at(device_num % devices->size());
+ auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+ auto* stream = device->tensorflow_gpu_device_info()->stream;
+ TestCase* test_case = test_cases[test_num].get();
+ NcclManager::instance()->AddToAllReduce(
+ num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
+ device->executor(), event_mgr, stream, &test_case->ins[device_num],
+ &test_case->outs[device_num], CreateDoneCallback(test_case));
+ };
+ pool->Schedule(fn);
+ }
+ pool.reset(); // wait for all work to be scheduled.
+
+ LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
+ << " collectives";
+ for (int i = 0; i < test_cases.size(); ++i) {
+ VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
+ }
+
+ int64 delta = Env::Default()->NowMicros() - start;
+ if (delta > time_limit_micros) {
+ LOG(ERROR) << "Ran for " << delta << " quitting";
+ break;
+ }
+ }
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
new file mode 100644
index 0000000000..db6ee3e0e7
--- /dev/null
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -0,0 +1,157 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include <unordered_map>
+#include <vector>
+
+#include "external/nccl_archive/src/nccl.h"
+#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// Base class for all communicator ops that use nccl.
+//
+// About memory management and stream syncing:
+// 1. The nccl communicator has a stream for each rank.
+// 2. For input tensors to the communicator, the compute stream is passed to the
+// NcclManager, which performs the needed
+// communicator_stream.ThenWaitFor(input_tensor_stream).
+// 3. The done_callback of the async kernel is not called by the
+// NcclManager until after the communicator kernel is complete. This
+// is enough to a) keep the input tensor data valid for the lifetime of the
+// collective; and b) ensure the data in the output tensor is available
+// when the async op kernel's done callback is called.
+class NcclAsyncOpBase : public AsyncOpKernel {
+ public:
+ NcclAsyncOpBase(OpKernelConstruction* c) : AsyncOpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("num_devices", &num_devices_));
+ OP_REQUIRES_OK(c, c->GetAttr("shared_name", &collective_prefix_));
+ }
+
+ string GetCollectiveKey(OpKernelContext* c) {
+ return strings::StrCat(collective_prefix_, ";", c->step_id(), ";",
+ c->frame_iter().frame_id, ":",
+ c->frame_iter().iter_id);
+ }
+
+ int num_devices() const { return num_devices_; }
+
+ private:
+ int num_devices_;
+ string collective_prefix_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(NcclAsyncOpBase);
+};
+
+// To execute a single all-reduce, this kernel is called once for each of the
+// <k> devices in the communicator.
+class NcclAllReduceOpKernel : public NcclAsyncOpBase {
+ public:
+ NcclAllReduceOpKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {
+ string reduction;
+ OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
+ if (reduction == "min") {
+ reduction_op_ = ncclMin;
+ } else if (reduction == "max") {
+ reduction_op_ = ncclMax;
+ } else if (reduction == "sum") {
+ reduction_op_ = ncclSum;
+ } else if (reduction == "prod") {
+ reduction_op_ = ncclProd;
+ } else {
+ OP_REQUIRES_OK(c,
+ errors::InvalidArgument("Invalid reduction: ", reduction));
+ }
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ const Tensor* in_t = &c->input(0);
+ Tensor* out_t;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, in_t->shape(), &out_t), done);
+
+ auto actual_done = [c, done](Status s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+
+ auto* compute_stream = c->op_device_context()->stream();
+ EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+ NcclManager::instance()->AddToAllReduce(
+ num_devices(), GetCollectiveKey(c), reduction_op_,
+ compute_stream->parent(), event_mgr, compute_stream, in_t, out_t,
+ actual_done);
+ }
+
+ private:
+ ncclRedOp_t reduction_op_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("NcclAllReduce").Device(DEVICE_GPU),
+ NcclAllReduceOpKernel);
+
+class NcclBroadcastSendKernel : public NcclAsyncOpBase {
+ public:
+ NcclBroadcastSendKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ auto actual_done = [c, done](Status s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+
+ auto* compute_stream = c->op_device_context()->stream();
+ EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+ NcclManager::instance()->AddBroadcastSend(
+ num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
+ compute_stream, &c->input(0), std::move(actual_done));
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
+ NcclBroadcastSendKernel);
+
+class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
+ public:
+ NcclBroadcastRecvKernel(OpKernelConstruction* c) : NcclAsyncOpBase(c) {}
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ const Tensor& shape_t = c->input(0);
+ TensorShape shape;
+ OP_REQUIRES_OK_ASYNC(
+ c, TensorShapeUtils::MakeShape(shape_t.vec<int64>(), &shape), done);
+ Tensor* out_t;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done);
+
+ auto actual_done = [c, done](Status s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+
+ auto* compute_stream = c->op_device_context()->stream();
+ EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+ NcclManager::instance()->AddBroadcastRecv(
+ num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
+ compute_stream, out_t, std::move(actual_done));
+ }
+};
+REGISTER_KERNEL_BUILDER(
+ Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"),
+ NcclBroadcastRecvKernel);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
new file mode 100644
index 0000000000..d767636fad
--- /dev/null
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -0,0 +1,94 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("NcclAllReduce")
+ .Input("input: T")
+ .Output("data: T")
+ .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .Attr("shared_name: string")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Outputs a tensor containing the reduction across all input tensors passed to ops
+within the same `shared_name`.
+
+The graph should be constructed so that if one op runs with shared_name value `c`,
+then `num_devices` ops will run with shared_name value `c`. Failure to do so
+will cause the graph execution to fail to complete.
+
+input: the input to the reduction
+data: the value of the reduction across all `num_devices` devices.
+reduction: the reduction operation to perform.
+num_devices: The number of devices participating in this reduction.
+shared_name: Identifier that is shared between ops of the same reduction.
+)doc");
+
+REGISTER_OP("NcclBroadcastSend")
+ .Input("input: T")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .Attr("shared_name: string")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`.
+
+The graph should be constructed so that one device runs `NcclBroadcastSend` and
+`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
+Failure to do so will cause the graph execution to fail to complete.
+
+input: The input to the broadcast
+num_devices: The number of devices participating in this broadcast.
+shared_name: Identifier that is shared between ops of the same broadcast.
+ )doc");
+
+REGISTER_OP("NcclBroadcastRecv")
+ .Input("shape: int64")
+ .Output("output: T")
+ .Attr("T: {float, float64, int32, int64}")
+ .Attr("num_devices: int")
+ .Attr("shared_name: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Receives data of shape `shape` broadcast from the NcclBroadcastSend op
+registered in the same `shared_name`.
+
+The graph should be constructed so that one device runs `NcclBroadcastSend` and
+`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`.
+Failure to do so will cause the graph execution to fail to complete.
+
+shape: The shape of the output.
+output: The broadcast data received from the NcclBroadcastSend op.
+num_devices: The number of devices participating in this broadcast.
+shared_name: Identifier that is shared between ops of the same broadcast.
+ )doc");
+
+} // namespace tensorflow
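
The op docs above impose a pairing contract: exactly `num_devices` ops sharing one `shared_name` must run together. A sketch of wiring the generated wrappers directly, under the assumption that the device names and the 'c0' shared_name value are illustrative (the Python wrapper in nccl_ops.py below normally generates the shared_name):

import tensorflow as tf
from tensorflow.contrib.nccl.ops import gen_nccl_ops

devices = ['/gpu:0', '/gpu:1', '/gpu:2']
shared_name = 'c0'  # illustrative; must match across all ops of this broadcast

with tf.device(devices[0]):
  src = tf.ones([3, 4])
  # The sender: one NcclBroadcastSend for the whole group.
  send = gen_nccl_ops.nccl_broadcast_send(
      input=src, num_devices=len(devices), shared_name=shared_name)
  shape = tf.shape(src, out_type=tf.int64)

# The receivers: num_devices - 1 NcclBroadcastRecv ops with the same key.
recvs = []
for d in devices[1:]:
  with tf.device(d):
    recvs.append(gen_nccl_ops.nccl_broadcast_recv(
        shape=shape, T=src.dtype, num_devices=len(devices),
        shared_name=shared_name))
# The send op and every recv must be run in the same Session.run call.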
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
new file mode 100644
index 0000000000..b31cc53e0a
--- /dev/null
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
@@ -0,0 +1,168 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for GPU collective operations implemented using NVIDIA nccl."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.contrib.nccl.ops import gen_nccl_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import resource_loader
+
+_nccl_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile('_nccl_ops.so'))
+
+
+def all_sum(tensors):
+ """Returns a list of tensors with the all-reduce sum across `tensors`.
+
+ The computation is done with an all-reduce operation, so if only some of the
+ returned tensors are evaluated then the computation will hang.
+
+ Args:
+ tensors: The input tensors across which to sum; must be assigned
+ to GPU devices.
+
+ Returns:
+ List of tensors, each with the sum of the input tensors, where tensor i has
+ the same device as `tensors[i]`.
+ """
+ return _apply_all_reduce('sum', tensors)
+
+
+def all_prod(tensors):
+ """Returns a list of tensors with the all-reduce product across `tensors`.
+
+ The computation is done with an all-reduce operation, so if only some of the
+ returned tensors are evaluated then the computation will hang.
+
+ Args:
+ tensors: The input tensors across which to multiply; must be assigned
+ to GPU devices.
+
+ Returns:
+ List of tensors, each with the product of the input tensors, where tensor i
+ has the same device as `tensors[i]`.
+ """
+ return _apply_all_reduce('prod', tensors)
+
+
+def all_min(tensors):
+ """Returns a list of tensors with the all-reduce min across `tensors`.
+
+ The computation is done with an all-reduce operation, so if only some of the
+ returned tensors are evaluated then the computation will hang.
+
+ Args:
+ tensors: The input tensors across which to reduce; must be assigned
+ to GPU devices.
+
+ Returns:
+ List of tensors, each with the minimum of the input tensors, where tensor i
+ has the same device as `tensors[i]`.
+ """
+ return _apply_all_reduce('min', tensors)
+
+
+def all_max(tensors):
+ """Returns a list of tensors with the all-reduce max across `tensors`.
+
+ The computation is done with an all-reduce operation, so if only some of the
+ returned tensors are evaluated then the computation will hang.
+
+ Args:
+ tensors: The input tensors across which to reduce; must be assigned
+ to GPU devices.
+
+ Returns:
+ List of tensors, each with the maximum of the input tensors, where tensor i
+ has the same device as `tensors[i]`.
+ """
+ return _apply_all_reduce('max', tensors)
+
+
+def broadcast(src_tensor, dst_devices):
+ """Returns a list of tensors on `dst_devices`, each with value `tensor`.
+
+ The computation is done with a broadcast nccl operation, so if only some of
+ the returned tensors and src_tensor are evaluated then the computation will
+ hang.
+
+ Args:
+ src_tensor: The tensor to send; must be assigned to a GPU device.
+ dst_devices: The GPU devices to receive the sent tensor.
+
+ Returns:
+ List of tensors, each with the value of `src_tensor`, where the device
+ of tensor i is `dst_devices[i]`.
+ """
+ if not dst_devices:
+ raise ValueError('Must pass >0 dst_devices to broadcast')
+ all_devices = [src_tensor.device] + dst_devices
+ shared_name = _get_shared_name()
+
+ with ops.device(src_tensor.device):
+ send = gen_nccl_ops.nccl_broadcast_send(
+ input=src_tensor, num_devices=len(all_devices), shared_name=shared_name)
+
+ shape_op = array_ops.shape(src_tensor, out_type=dtypes.int64)
+ recvs = []
+ for d in dst_devices:
+ with ops.device(d):
+ recvs.append(
+ gen_nccl_ops.nccl_broadcast_recv(
+ shape=shape_op,
+ T=src_tensor.dtype,
+ num_devices=len(all_devices),
+ shared_name=shared_name))
+
+ return send, recvs
+
+
+def _apply_all_reduce(reduction_op, tensors):
+ if not tensors:
+ raise ValueError('Must pass >0 tensors to all reduce operations')
+ shared_name = _get_shared_name()
+ res = []
+ for t in tensors:
+ if not device.canonical_name(t.device):
+ raise ValueError('Device assignment required for nccl collective ops')
+ with ops.device(t.device):
+ res.append(
+ gen_nccl_ops.nccl_all_reduce(
+ t,
+ reduction=reduction_op,
+ num_devices=len(tensors),
+ shared_name=shared_name))
+ return res
+
+
+_lock = threading.Lock()
+_shared_name_counter = 0
+
+
+def _get_shared_name():
+ global _shared_name_counter
+
+ with _lock:
+ val = _shared_name_counter
+ _shared_name_counter += 1
+ return 'c%s' % val
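
A short usage sketch of the `broadcast` wrapper defined above (same caveats: CUDA build, multiple GPUs, illustrative device names). The returned send op takes part in the collective, so it must be fetched together with the received tensors or the run hangs:

import tensorflow as tf
from tensorflow.contrib import nccl

with tf.device('/gpu:0'):
  src = tf.random_uniform([3, 4])

# Broadcast src from /gpu:0 to two receiving devices.
send_op, received = nccl.broadcast(src, ['/gpu:1', '/gpu:2'])

with tf.Session() as sess:
  # Fetch the send op alongside the received tensors; drop its (None) result.
  outputs = sess.run(received + [send_op])[:-1]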
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
new file mode 100644
index 0000000000..130cb4ca12
--- /dev/null
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -0,0 +1,151 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nccl ops. See also the cc test for nccl_communicator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import nccl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class AllReduceTest(test.TestCase):
+
+ def testAllReduce(self):
+ if not test.is_gpu_available():
+ return # Test requires access to a GPU
+
+ for dtype in [np.float32, np.int32, np.int64, np.float64]:
+ # Create session inside outer loop to test use of
+ # same communicator across multiple sessions.
+ with self.test_session(use_gpu=True) as sess:
+ self._testSingleAllReduce(sess, dtype, nccl.all_sum, lambda x, y: x + y)
+ self._testSingleAllReduce(sess, dtype, nccl.all_prod,
+ lambda x, y: x * y)
+ self._testSingleAllReduce(sess, dtype, nccl.all_min, np.minimum)
+ self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum)
+
+ def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn):
+ for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+ shape = (3, 4)
+ np_ans = None
+ tensors = []
+ for d in devices:
+ with ops.device(d):
+ t = ((np.random.random_sample(shape) - .5) * 1024).astype(np_type)
+ if np_ans is None:
+ np_ans = t
+ else:
+ np_ans = numpy_accumulation_fn(np_ans, t)
+ tensors.append(array_ops.identity(t))
+
+ all_reduce_tensors = nccl_fn(tensors)
+
+ # Test shape inference.
+ for r in all_reduce_tensors:
+ self.assertEqual(shape, r.get_shape())
+
+ # Test execution and results.
+ nccl_results = sess.run(all_reduce_tensors)
+ for r in nccl_results:
+ self.assertAllClose(r, np_ans)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError, 'Device assignment required'):
+ nccl.all_sum([array_ops.identity(np.random.random_sample((3, 4)))])
+ with self.assertRaisesRegexp(ValueError, 'Must pass >0 tensors'):
+ nccl.all_sum([])
+
+
+class BroadcastTest(test.TestCase):
+
+ def testBroadcast(self):
+ if not test.is_gpu_available():
+ return # Test requires access to a GPU
+
+ for dtype in [np.float32, np.int32, np.int64, np.float64]:
+ # Create session inside outer loop to test use of
+ # same communicator across multiple sessions.
+ with self.test_session(use_gpu=True) as sess:
+ for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+ shape = (3, 4)
+ sender = np.random.randint(0, len(devices))
+ with ops.device(devices[sender]):
+ np_ans = ((
+ (np.random.random_sample(shape) - .5) * 1024).astype(dtype))
+ t = array_ops.identity(np_ans)
+ other_devices = devices[:sender] + devices[sender + 1:]
+ send_op, received_tensors = nccl.broadcast(t, other_devices)
+
+ # Verify shape inference.
+ for r in received_tensors:
+ self.assertEqual(shape, r.get_shape())
+
+ # Run and verify results.
+ nccl_results = sess.run(received_tensors + [send_op])
+ for r in nccl_results[:-1]:
+ self.assertAllClose(r, np_ans)
+
+
+class CombinedTest(test.TestCase):
+ """Tests using a mix of all-reduce ops in one session.run call."""
+
+ def testCombined(self):
+ if not test.is_gpu_available():
+ return # Test requires access to a GPU
+
+ for dtype in [np.float32, np.int32, np.int64, np.float64]:
+ # Create session inside outer loop to test use of
+ # same communicator across multiple sessions.
+ with self.test_session(use_gpu=True) as sess:
+ for devices in [['/gpu:0', '/gpu:0', '/gpu:0'], ['/gpu:0', '/gpu:0']]:
+ shape = (3, 4)
+
+ # all-reduce
+ np_ans = np.zeros(shape=shape, dtype=dtype)
+ tensors = []
+ for d in devices:
+ with ops.device(d):
+ t = ((np.random.random_sample(shape) - .5) * 1024).astype(dtype)
+ np_ans += t
+ tensors.append(array_ops.identity(t))
+ all_reduce_tensors = nccl.all_sum(tensors)
+
+ sender = np.random.randint(0, len(devices))
+ other_devices = devices[:sender] + devices[sender + 1:]
+ send_op, received_tensors = nccl.broadcast(all_reduce_tensors[sender],
+ other_devices)
+
+ # sender doesn't need to be fetched as part of outputs of session.run.
+ del all_reduce_tensors[sender]
+
+ # Verify shape inference.
+ for r in received_tensors:
+ self.assertEqual(shape, r.get_shape())
+
+ # Run and verify results.
+ nccl_results = sess.run(
+ received_tensors + [send_op] + all_reduce_tensors)
+ for r in nccl_results[:len(received_tensors)]:
+ self.assertAllClose(r, np_ans)
+
+
+if __name__ == '__main__':
+ test.main()