Diffstat (limited to 'tensorflow/contrib/mpi_collectives/mpi_ops.cc')
-rw-r--r-- tensorflow/contrib/mpi_collectives/mpi_ops.cc 1236
1 file changed, 1236 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
new file mode 100644
index 0000000000..a051ab0004
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
@@ -0,0 +1,1236 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include <cuda_runtime.h>
+#include "tensorflow/stream_executor/stream.h"
+#endif
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+/*
+ * MPI Allreduce and Allgather Ops for TensorFlow.
+ *
+ * TensorFlow natively provides inter-device communication through send and
+ * receive ops and inter-node communication through Distributed TensorFlow,
+ * based on the same send and receive abstractions. These end up being
+ * insufficient for synchronous data-parallel training on HPC clusters where
+ * Infiniband or other high-speed interconnects are available. This module
+ * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
+ * gathers and reductions and can take advantage of hardware-optimized
+ * communication libraries through the MPI implementation.
+ *
+ * The primary logic of the allreduce and allgather is in RingAllgather() and
+ * RingAllreduce(). The background thread that facilitates MPI operations is
+ * run in BackgroundThreadLoop(). The provided MPI ops are:
+ * – MPIInit:
+ * Initialize MPI on a given device (CPU or GPU).
+ * Should only be run on a single device in every process.
+ * – MPISize:
+ * Get the number of MPI processes in the global communicator.
+ * – MPIRank:
+ * Get the rank of the current MPI process in the global communicator.
+ * – MPILocalRank:
+ * Get the local rank of the current MPI process within its node.
+ * – MPIAllreduce:
+ * Perform an allreduce on a Tensor, returning the sum
+ * across all MPI processes in the global communicator.
+ * – MPIAllgather:
+ * Perform an allgather on a Tensor, returning the concatenation of
+ * the tensor on the first dimension across all MPI processes in the
+ * global communicator.
+ *
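+ * As a concrete example, in synchronous data-parallel training each of N
+ * ranks holds a local gradient tensor g_i; running MPIAllreduce on g_i
+ * returns g_0 + ... + g_{N-1} on every rank, which each rank can divide by
+ * N to obtain the averaged gradient.
+ *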
+ */
+
+template <class T>
+using StatusOr = perftools::gputools::port::StatusOr<T>;
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+// Make sure template specializations are generated in the ring.cu.cc and the
+// ring.cc file, not in this file.
+extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<GPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<CPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+
+namespace {
+
+// Return true if the templated type is GPUDevice, otherwise false.
+template <typename T>
+bool IsGPUDevice();
+template <>
+bool IsGPUDevice<GPUDevice>() {
+ return true;
+}
+template <>
+bool IsGPUDevice<CPUDevice>() {
+ return false;
+}
+
+// A callback to call after the MPI communication completes. Since the
+// allreduce and allgather ops are asynchronous, this callback is what resumes
+// computation after the reduction is completed.
+typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
+
+struct CollectiveOpRecord {
+ // The rank performing this piece of the op
+ int rank;
+
+ // The name of the op/tensor to be reduced
+ std::string name;
+
+ // The op's kernel context
+ OpKernelContext* context;
+
+ // Data type of the op
+ DataType dtype;
+
+ // The input tensor
+ const Tensor* in_t;
+
+ // Allgather: Vector of per-rank first-dimension sizes
+ std::vector<size_t> sizes_vec;
+
+ // The temp tensor for intermediate results
+ Tensor temp_t;
+
+ // The output tensor
+ Tensor* out_t;
+
+ // Whether to run this op on the GPU.
+ bool on_gpu;
+
+ // The callback to call after the op has completed
+ CommunicationDoneCallback callback;
+};
+
+// Table storing Tensors to be reduced, keyed by unique name.
+// This table contains everything necessary to do the reduction
+typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking and size calculations, as well as determining when a reduction is
+// ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct MPIGlobalState {
+ // An atomic boolean which is set to true when MPI is initialized.
+ // This ensures that MPI_Init is never called twice.
+ std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
+
+ // Condition variable to wait for initialization
+ condition_variable cv;
+
+ // Whether MPI_Init has been completed on the background thread.
+ bool initialization_done = false;
+
+ // Whether MPI_Init succeeded on the background thread.
+ Status init_status;
+
+ // A mutex that needs to be used whenever MPI operations touch
+ // shared structures.
+ mutex mu;
+
+ // Tensors waiting to be allreduced or allgathered.
+ TensorTable tensor_table;
+
+ // Queue of MPI requests waiting to be sent to the coordinator node.
+ std::queue<MPIRequest> message_queue;
+
+ // Background thread running MPI communication.
+ std::thread background_thread;
+
+ // Whether the background thread should shutdown.
+ bool shut_down = false;
+
+ // Only exists on the coordinator node (rank zero). Maintains a count of
+ // how many nodes are ready to allreduce every tensor (keyed by tensor
+ // name).
+ std::unique_ptr<MessageTable> message_table;
+
+ // The MPI rank, local rank, and size.
+ int rank = 0;
+ int local_rank = 0;
+ int size = 1;
+
+ // The device that MPI was initialized on. (-1 for no GPU)
+ int device = -1;
+
+ // The CUDA stream used for data transfers and within-allreduce operations.
+ // A naive implementation would use the TensorFlow StreamExecutor CUDA
+ // stream. However, the allreduce and allgather require doing memory copies
+ // and kernel executions (for accumulation of values on the GPU), and the
+ // subsequent operations must wait for those to complete; otherwise MPI
+ // (which uses its own stream internally) will begin the data transfers
+ // before the CUDA calls are complete. To wait for those CUDA operations
+ // while using the TensorFlow stream, we would have to synchronize that
+ // stream; however, other TensorFlow threads may be submitting more work to
+ // it, so synchronizing on it can delay the allreduce while it waits for
+ // compute totally unrelated to it in other parts of the graph. Overlapping
+ // memory transfers and compute during backpropagation is crucial for good
+ // performance, so we cannot use the TensorFlow stream and must use our own.
+#if GOOGLE_CUDA
+ cudaStream_t stream;
+ std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
+#endif
+
+ ~MPIGlobalState() {
+ // Make sure that the destructor of the background thread is safe to
+ // call. If a thread is still joinable (not detached or complete) its
+ // destructor cannot be called.
+ if (background_thread.joinable()) {
+ shut_down = true;
+ background_thread.join();
+ }
+ }
+};
+
+// All the MPI state that must be stored globally per-process.
+static MPIGlobalState mpi_global;
+
+// For clarity in argument lists.
+#define RANK_ZERO 0
+
+// A tag used for all coordinator messaging.
+#define TAG_NOTIFY 1
+
+// Store the MPIRequest for a name, and return whether the total count of
+// MPIRequests for that tensor is now equal to the MPI size (and thus we are
+// ready to reduce the tensor).
+bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
+ MPIRequest msg, int mpi_size) {
+ auto name = msg.tensor_name();
+ auto table_iter = message_table->find(name);
+ if (table_iter == message_table->end()) {
+ message_table->emplace(name, std::vector<MPIRequest>({msg}));
+ table_iter = message_table->find(name);
+ } else {
+ table_iter->second.push_back(msg);
+ }
+
+ int count = table_iter->second.size();
+ return count == mpi_size;
+}
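+
+// A minimal illustration (hypothetical, not taken from any test in this
+// file): with an MPI size of 2, the first request recorded for a tensor
+// returns false and the second returns true, signalling that the tensor is
+// ready to reduce:
+//
+//   std::unique_ptr<MessageTable> table(new MessageTable());
+//   MPIRequest req;
+//   req.set_tensor_name("grad_0");
+//   req.set_request_type(MPIRequest::ALLREDUCE);
+//   bool ready = IncrementTensorCount(table, req, /*mpi_size=*/2);  // false
+//   ready = IncrementTensorCount(table, req, /*mpi_size=*/2);       // true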
+
+// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
+// to all ranks instructing them to start the reduction. The MPIResponse
+// also contains error messages in case the submitted MPIRequests were not
+// valid (for example, contained mismatched shapes or types).
+//
+// Constructing the MPIResponse thus requires a fair amount of error checking.
+MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
+ std::string name) {
+ bool error = false;
+ auto it = message_table->find(name);
+ assert(it != message_table->end());
+
+ std::vector<MPIRequest> requests = it->second;
+ assert(requests.size() > 0);
+
+ std::ostringstream error_message_stream;
+
+ // Check that all data types being reduced or gathered are identical
+ auto data_type = requests[0].tensor_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ auto request_type = requests[i].tensor_type();
+ if (data_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched data types: One rank had type "
+ << DataType_Name(data_type)
+ << ", but another rank had type "
+ << DataType_Name(request_type) << ".";
+ break;
+ }
+ }
+
+ // Check that all requested operations are the same
+ auto message_type = requests[0].request_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ auto request_type = requests[i].request_type();
+ if (message_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched MPI operations: One rank did an "
+ << message_type << ", but another rank did an "
+ << request_type << ".";
+ break;
+ }
+ }
+
+ // If we are doing an allreduce, check that all tensor shapes
+ // are identical
+ if (message_type == MPIRequest::ALLREDUCE) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape != request_shape) {
+ error = true;
+ error_message_stream << "Mismatched allreduce tensor shapes: "
+ << "One rank reduced a tensor of shape "
+ << tensor_shape.DebugString()
+ << ", but another rank sent a tensor of shape "
+ << request_shape.DebugString() << ".";
+ break;
+ }
+ }
+ }
+
+ // If we are doing an allgather, make sure all but the first dimension are
+ // the same. The first dimension may differ, and the output tensor's first
+ // dimension is the sum of the per-rank first dimensions. Collect the sizes
+ // by rank.
+ if (message_type == MPIRequest::ALLGATHER) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+
+ if (tensor_shape.dims() == 0) {
+ error = true;
+ error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
+ }
+
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape.dims() != request_shape.dims()) {
+ error = true;
+ error_message_stream << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor of rank "
+ << tensor_shape.dims()
+ << ", but another rank sent a tensor of rank "
+ << request_shape.dims() << ".";
+ break;
+ }
+
+ for (int dim = 1; dim < tensor_shape.dims(); dim++) {
+ if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
+ error = true;
+ error_message_stream
+ << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor with dimension " << dim
+ << " equal to " << tensor_shape.dim_size(dim)
+ << ", but another rank sent a tensor with dimension " << dim
+ << " equal to " << request_shape.dim_size(dim) << ".";
+ break;
+ }
+ }
+ }
+ }
+
+ MPIResponse response;
+ response.set_tensor_name(name);
+ if (error) {
+ std::string error_message = error_message_stream.str();
+ response.set_response_type(MPIResponse::ERROR);
+ response.set_error_message(error_message);
+ } else {
+ auto response_type = MPIResponse::ERROR;
+ if (message_type == MPIRequest::ALLREDUCE) {
+ response_type = MPIResponse::ALLREDUCE;
+ } else {
+ response_type = MPIResponse::ALLGATHER;
+ }
+ response.set_response_type(response_type);
+ }
+
+ // Clear all queued up requests for this name. They are now taken care of
+ // by the constructed MPI response.
+ message_table->erase(it);
+
+ return response;
+}
+
+// Process an MPIResponse by doing a reduction, a gather, or raising an error.
+void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
+ OpKernelContext* context;
+ const Tensor* input_tensor;
+ std::vector<size_t> sizes_vec;
+ Tensor temp_tensor;
+ Tensor* output_tensor;
+ CommunicationDoneCallback callback;
+ bool on_gpu;
+ {
+ // Lock on the tensor table.
+ mutex_lock guard(mpi_global.mu);
+
+ // We should never fail at finding this key in the tensor table.
+ auto name = response.tensor_name();
+ auto iter = tensor_table.find(name);
+ assert(iter != tensor_table.end());
+
+ assert(response.response_type() == MPIResponse::ALLREDUCE ||
+ response.response_type() == MPIResponse::ALLGATHER ||
+ response.response_type() == MPIResponse::ERROR);
+
+ CollectiveOpRecord record = iter->second;
+ context = record.context;
+ input_tensor = record.in_t;
+ sizes_vec = record.sizes_vec;
+ temp_tensor = record.temp_t;
+ output_tensor = record.out_t;
+ on_gpu = record.on_gpu;
+ callback = record.callback;
+
+ // Clear the tensor table of this tensor and its callbacks; the rest of
+ // this function takes care of it.
+ tensor_table.erase(iter);
+ }
+
+ // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
+ // link to non-existent symbols.
+#if GOOGLE_CUDA
+#define GPU_DEVICE_IF_CUDA GPUDevice
+#else
+#define GPU_DEVICE_IF_CUDA CPUDevice
+#endif
+
+ Status status;
+ auto dtype = input_tensor->dtype();
+ if (response.response_type() == MPIResponse::ALLGATHER) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, float>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, int>(context, input_tensor,
+ sizes_vec, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, long long>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allgather.");
+ }
+ } else if (response.response_type() == MPIResponse::ALLREDUCE) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, float>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, int>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, long long>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allreduce.");
+ }
+ } else if (response.response_type() == MPIResponse::ERROR) {
+ status = errors::FailedPrecondition(response.error_message());
+ }
+
+ if (status.ok()) {
+ callback(StatusOr<Tensor>(*output_tensor));
+ } else {
+ callback(StatusOr<Tensor>(status));
+ }
+}
+
+// The MPI background thread loop coordinates all the MPI processes and the
+// tensor reductions. The design of the communicator mechanism is limited by a
+// few considerations:
+//
+// 1. Some MPI implementations require all MPI calls to happen from a
+// single thread. Since TensorFlow may use several threads for graph
+// processing, this means we must have our own dedicated thread for
+// dealing with MPI.
+// 2. We want to gracefully handle errors, when MPI processes do not
+// properly agree upon what should happen (such as mismatched types or
+// shapes). To do so requires the MPI processes to know about the shapes
+// and types of the relevant tensors on the other processes.
+// 3. The MPI reductions and gathers should be able to happen in parallel
+// with other ongoing operations. Since MPI uses an internal
+// (inaccessible) GPU stream separate from the TF GPUDevice streams, we
+// cannot explicitly synchronize memcpys or kernels with it. As a result,
+// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
+// ordering of memcpys and kernels with respect to TF streams.
+// 4. NOTE: We cannot guarantee that all the MPI processes reduce their
+// tensors in the same order. Thus, there must be a way to ensure the
+// reduction memcpys and kernels occur for correct tensors across all
+// ranks at the same time. We choose to use a coordinator (rank ID 0) to
+// gather and trigger the reduction operations that are ready to execute.
+//
+// The coordinator currently follows a master-worker paradigm. Rank zero acts
+// as the master (the "coordinator"), whereas all other ranks are simply
+// workers. Each rank runs its own background thread which progresses in ticks.
+// In each tick, the following actions happen:
+//
+// a) The workers send any available MPIRequests to the coordinator. An
+// MPIRequest indicates what the worker would like to do (i.e. which tensor
+// it would like to gather or reduce, as well as the tensor's shape and
+// type). A worker sends one such request for each tensor it wants to
+// operate on, once that tensor's collective op has executed ComputeAsync.
+//
+// b) The workers send an empty "DONE" message to the coordinator to
+// indicate that there are no more tensors they wish to operate on.
+//
+// c) The coordinator receives the MPIRequests from the workers, as well
+// as from its own TensorFlow ops, and stores them in a request table. The
+// coordinator continues to receive MPIRequest messages until it has
+// received a DONE message from every worker.
+//
+// d) The coordinator finds all tensors that are ready to be reduced or
+// gathered, as well as all operations that resulted in an error. For each
+// of those, it sends an MPIResponse to all the workers. When no more
+// MPIResponses are available, it sends a "DONE" response to the workers. If
+// the process is being shut down, it instead sends a "SHUTDOWN" response.
+//
+// e) The workers listen for MPIResponse messages, processing each one by
+// doing the required reduce or gather, until they receive a "DONE"
+// response from the coordinator. At that point, the tick ends.
+// If instead of "DONE" they receive "SHUTDOWN", they exit their
+// background loop.
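+//
+// As a concrete (illustrative) trace of one tick with two ranks reducing a
+// single tensor "t":
+//
+//   rank 0           : records its own MPIRequest{ALLREDUCE, "t"}  (step a)
+//   rank 1 -> rank 0 : MPIRequest{ALLREDUCE, "t"}                  (step a)
+//   rank 1 -> rank 0 : empty DONE message                          (step b)
+//   rank 0           : count for "t" reaches the MPI size of 2     (step c)
+//   rank 0 -> rank 1 : MPIResponse{ALLREDUCE, "t"}                 (step d)
+//   both ranks       : PerformCollectiveOp() runs RingAllreduce()  (step e)
+//   rank 0 -> rank 1 : MPIResponse{DONE}; the tick ends.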
+// TODO: Use the global mpi_global state variable instead of a local one
+void BackgroundThreadLoop() {
+#if GOOGLE_CUDA
+ // Set the device, so that this thread uses the same GPU context as the
+ // calling thread.
+ // TODO: Ensure that this is operating correctly. The background thread
+ // needs to be able to control all GPUs that the rank has access to, which
+ // might be more than one. Tensors could be resident in any of the GPUs, so
+ // the background thread's accumulate and copy kernels might need to
+ // correctly set the device, and it might be necessary for the background
+ // thread to manage multiple streams.
+ cudaSetDevice(mpi_global.device);
+ cudaStreamCreate(&mpi_global.stream);
+#endif
+
+ // Initialize MPI. This must happen on the background thread, since not all
+ // MPI implementations support being called from multiple threads.
+ auto init_result = MPI_Init(NULL, NULL);
+ if (init_result != MPI_SUCCESS) {
+ mpi_global.init_status =
+ errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
+ mpi_global.initialization_done = true;
+ mpi_global.cv.notify_all();
+ return;
+ } else {
+ mpi_global.init_status = Status::OK();
+ }
+
+ // Get MPI rank to determine if we are rank zero.
+ int rank;
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+ bool is_coordinator = rank == 0;
+
+ // Get MPI size to determine how many tensors to wait for before reducing.
+ int size;
+ MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+ // Determine local rank by querying the local communicator.
+ MPI_Comm local_comm;
+ MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
+ &local_comm);
+ int local_rank;
+ MPI_Comm_rank(local_comm, &local_rank);
+
+ mpi_global.rank = rank;
+ mpi_global.local_rank = local_rank;
+ mpi_global.size = size;
+ mpi_global.initialization_done = true;
+
+ // Notify calling thread that initialization is complete
+ mpi_global.cv.notify_all();
+
+ // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
+ // Initialize the tensor count table. No tensors are available yet.
+ if (is_coordinator) {
+ mpi_global.message_table =
+ std::unique_ptr<MessageTable>(new MessageTable());
+ }
+
+ // The coordinator sends a SHUTDOWN message to trigger shutdown.
+ bool should_shut_down = false;
+ do {
+ // TODO: Eliminate the need for thread sleep by making all activity
+ // depend on other activity (e.g. condition or MPI waits).
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+ // Copy the data structures from global state under this lock.
+ // However, don't keep the lock for the rest of the loop, so that
+ // enqueued stream callbacks can continue.
+ std::queue<MPIRequest> message_queue;
+ {
+ mutex_lock guard(mpi_global.mu);
+ while (!mpi_global.message_queue.empty()) {
+ MPIRequest message = mpi_global.message_queue.front();
+ mpi_global.message_queue.pop();
+ message_queue.push(message);
+ }
+ }
+
+ // Collect all tensors that are ready to be reduced. Record them in the
+ // tensor count table (rank zero) or send them to rank zero to be
+ // recorded (everyone else).
+ std::vector<std::string> ready_to_reduce;
+ while (!message_queue.empty()) {
+ // Pop the first available message.
+ MPIRequest message = message_queue.front();
+ message_queue.pop();
+
+ if (is_coordinator) {
+ bool reduce =
+ IncrementTensorCount(mpi_global.message_table, message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(message.tensor_name());
+ }
+ } else {
+ std::string encoded_message;
+ message.SerializeToString(&encoded_message);
+ MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
+ MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ }
+
+ // Rank zero has put all its own tensors in the tensor count table.
+ // Now, it should count all the tensors that are coming from other
+ // ranks at this tick. It should keep getting tensors until it gets a
+ // DONE message from all the other ranks.
+ if (is_coordinator) {
+ // Count of DONE messages. Keep receiving messages until the number
+ // of messages is equal to the number of processes. Initialize to
+ // one since the coordinator is effectively done.
+ int completed_ranks = 1;
+ while (completed_ranks != size) {
+ MPI_Status status;
+ MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int source_rank = status.MPI_SOURCE;
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // If the length is zero, this is a DONE message.
+ if (msg_length == 0) {
+ completed_ranks++;
+ MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ continue;
+ }
+
+ // Receive the serialized MPIRequest into a std::string. The sender
+ // transmitted length() + 1 bytes (including the terminating zero byte),
+ // and the serialized protobuf may itself contain zero bytes, so construct
+ // the string with an explicit length rather than stopping at the first NUL.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
+ MPI_COMM_WORLD, &status);
+ std::string received_data(buffer, msg_length - 1);
+ delete[] buffer;
+
+ MPIRequest received_message;
+ received_message.ParseFromString(received_data);
+ auto received_name = received_message.tensor_name();
+
+ bool reduce = IncrementTensorCount(mpi_global.message_table,
+ received_message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(received_name);
+ }
+ }
+
+ // At this point, rank zero should have a fully updated tensor
+ // count table and should know all the tensors that need to be
+ // reduced or gathered, and everyone else should have sent all
+ // their information to rank zero. We can now do reductions and
+ // gathers; rank zero will choose which ones and in what order,
+ // and will notify the other ranks before doing each reduction.
+ for (size_t i = 0; i < ready_to_reduce.size(); i++) {
+ // Notify all nodes which tensor we'd like to reduce now
+ auto name = ready_to_reduce[i];
+ MPIResponse response =
+ ConstructMPIResponse(mpi_global.message_table, name);
+
+ std::string encoded_response;
+ response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+
+ // Perform the reduction. All nodes should end up performing
+ // the same reduction.
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+
+ // Notify all nodes that we are done with the reductions for this
+ // tick.
+ MPIResponse done_response;
+ should_shut_down = mpi_global.shut_down;
+ done_response.set_response_type(
+ mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
+ std::string encoded_response;
+ done_response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ } else {
+ // Notify the coordinator that this node is done sending messages.
+ // A DONE message is encoded as a zero-length message.
+ MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+ // Receive names for tensors to reduce from rank zero. Once we
+ // receive an empty DONE message, stop waiting for more names.
+ while (true) {
+ MPI_Status status;
+ MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // Receive the serialized MPIResponse into a std::string. As above,
+ // construct the string with an explicit length, since the serialized
+ // protobuf may contain embedded zero bytes.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ std::string received_message(buffer, msg_length - 1);
+ delete[] buffer;
+
+ MPIResponse response;
+ response.ParseFromString(received_message);
+ if (response.response_type() == MPIResponse::DONE) {
+ // No more messages this tick
+ break;
+ } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+ // No more messages this tick, and the background thread
+ // should shut down
+ should_shut_down = true;
+ break;
+ } else {
+ // Process the current message
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+ }
+ }
+ } while (!should_shut_down);
+
+ MPI_Finalize();
+}
+
+// Initialize MPI and start the MPI background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+Status InitializeMPIOnce(bool gpu) {
+ // Ensure MPI is only initialized once.
+ if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
+
+ mpi_global.device = -1;
+#if GOOGLE_CUDA
+ if (gpu) {
+ cudaGetDevice(&mpi_global.device);
+ }
+#endif
+
+ // Start the MPI background thread, which assumes MPI is initialized
+ // TODO: Change this to a Tensorflow thread
+ mpi_global.background_thread = std::thread(BackgroundThreadLoop);
+
+ // Wait to ensure that the background thread has finished initializing MPI.
+ // Wait in a predicate loop so that a notification delivered before we start
+ // waiting, or a spurious wakeup, cannot leave us blocked or returning too
+ // early.
+ mutex_lock guard(mpi_global.mu);
+ while (!mpi_global.initialization_done) {
+ mpi_global.cv.wait(guard);
+ }
+
+ return mpi_global.init_status;
+}
+
+// Check that MPI is initialized.
+Status IsMPIInitialized() {
+ if (!mpi_global.initialization_done) {
+ return errors::FailedPrecondition(
+ "MPI has not been initialized; use tf.contrib.mpi.Session.");
+ }
+ return Status::OK();
+}
+
+// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
+// only adds the op's record to the local tensor table (to track the op's
+// progress) and enqueues a message indicating that this rank is ready to
+// begin. The MPI background thread will handle sending the MPI message.
+void EnqueueTensorCollective(CollectiveOpRecord record,
+ MPIRequest::RequestType rtype) {
+ const Tensor* input_tensor = record.in_t;
+ MPIRequest message;
+ message.set_request_rank(record.rank);
+ message.set_tensor_name(record.name);
+ message.set_tensor_type(record.dtype);
+ message.set_request_type(rtype);
+ input_tensor->shape().AsProto(message.mutable_tensor_shape());
+
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.tensor_table.emplace(record.name, record);
+ mpi_global.message_queue.push(message);
+}
+
+} // namespace
+
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
+#endif
+
+// Op to initialize MPI in the current process. The device this op runs on
+// (CPU or GPU) is the device that all future MPI ops must use.
+template <typename Device>
+class MPIInitOp : public OpKernel {
+ public:
+ explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool on_gpu = IsGPUDevice<Device>();
+ OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
+ MPIInitOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
+ MPIInitOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIInit").Doc(R"doc(
+Initialize MPI for the current process.
+
+If this is run on a GPU, then that GPU must be used for all future MPI
+operations. If it is run on CPU, then all future MPI operations must also
+run on CPU.
+)doc");
+
+// Op to get the current MPI Size.
+template <typename Device>
+class MPISizeOp : public OpKernel {
+ public:
+ explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.size;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
+ MPISizeOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
+ MPISizeOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPISize")
+ .Output("size: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the number of running MPI processes.
+
+More precisely, returns the number of MPI processes in the group associated
+with the MPI_COMM_WORLD communicator.
+
+size: Size of the MPI group.
+)doc");
+
+// Op to get the current MPI Rank.
+template <typename Device>
+class MPIRankOp : public OpKernel {
+ public:
+ explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
+ MPIRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPIRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the MPI group.
+
+More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
+communicator.
+
+rank: Rank of the calling process.
+)doc");
+
+// Op to get the current local MPI Rank.
+template <typename Device>
+class MPILocalRankOp : public OpKernel {
+ public:
+ explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.local_rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
+ MPILocalRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPILocalRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPILocalRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the node it is on.
+
+More precisely, returns the rank of the calling process in a communicator that
+spans only the MPI processes running on the same node.
+
+rank: Rank of the calling process on the node it is on.
+)doc");
+
+template <typename Device>
+class MPIAllreduceOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllreduceOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, input_tensor->shape(), &output_tensor),
+ done);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.out_t = output_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+ record.dtype = input_tensor->dtype();
+
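+ // The ring allreduce works on the tensor in (up to) mpi_global.size equal
+ // segments, so the scratch buffer is sized to hold one segment of
+ // ceil(NumElements / size) elements.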
+ const size_t temp_size =
+ (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
+ TensorShape temp_shape;
+ temp_shape.AddDim(temp_size);
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_temp(input_tensor->dtype(),
+ temp_shape, &record.temp_t),
+ done);
+
+ auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allreduce_done_callback;
+
+ auto allreduce_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allreduce; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allreduce_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allreduce_launch_callback);
+ }
+#else
+ allreduce_launch_callback();
+#endif
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
+ MPIAllreduceOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
+ MPIAllreduceOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIAllreduce")
+ .Attr("T: {int32, int64, float32}")
+ .Input("tensor: T")
+ .Output("sum: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allreduce on a tensor. Tensors are reduced with tensors that
+have the same node name on the other MPI processes, and every process
+reducing a tensor with a given name must use the same shape for that tensor.
+
+Arguments
+ tensor: A tensor to reduce.
+
+Output
+ sum: A tensor with the same shape as `tensor`, summed across all
+ MPI processes.
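+
+For example, with four MPI processes each contributing a tensor of shape
+[2, 3], every process receives the elementwise sum of the four tensors, also
+of shape [2, 3].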
+)doc");
+
+template <typename Device>
+class MPIAllgatherOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllgatherOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ const Tensor* sizing_tensor = &context->input(1);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+
+ // Construct the output size from the sizing tensor
+ size_t output_first_dim = 0;
+ if (sizing_tensor->shape().dims() == 0) {
+ // 0-dim sizing_tensor implies that the op is just gathering
+ // a single element from each rank
+ output_first_dim = mpi_global.size;
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(1);
+ }
+ } else {
+ // Collect the total output tensor sizing from the sizing tensor
+ // NOTE: The sizing tensor is forced to be placed on the CPU by
+ // declaring the input as HostMemory, so it is valid to read it here.
+ const int64* sizing_array =
+ reinterpret_cast<const int64*>(sizing_tensor->tensor_data().data());
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(sizing_array[i]);
+ output_first_dim += sizing_array[i];
+ }
+ }
+
+ TensorShape output_shape;
+ output_shape.AddDim(output_first_dim);
+ for (int i = 1; i < input_tensor->shape().dims(); i++) {
+ output_shape.AddDim(input_tensor->shape().dim_size(i));
+ }
+
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output_tensor),
+ done);
+
+ record.out_t = output_tensor;
+ record.dtype = input_tensor->dtype();
+
+ auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allgather_done_callback;
+
+ auto allgather_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allgather; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allgather_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allgather_launch_callback);
+ }
+#else
+ allgather_launch_callback();
+#endif
+ }
+};
+
+REGISTER_OP("MPIAllgather")
+ .Attr("T: {int32, int64, float32}")
+ .Attr("S: {int64}")
+ .Input("tensor: T")
+ .Input("sizes: S")
+ .Output("gathered: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle output;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+ c->set_output(0, output);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allgather on a tensor. All other processes that do a gather on
+a tensor with the same name must have the same rank for that tensor, and the
+same size in every dimension except the first.
+
+Arguments
+ tensor: A tensor to gather.
+ sizes: A tensor containing the first-dimension sizes of tensors to be
+ gathered from other ranks.
+
+Output
+ gathered: A tensor with the same shape as `tensor` except for the first
+ dimension, which is the sum of dimensions in `sizes`.
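+
+For example, with three MPI processes gathering tensors of shapes [2, 5],
+[3, 5], and [4, 5], every process passes sizes = [2, 3, 4] and receives a
+gathered tensor of shape [9, 5].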
+)doc");
+
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
+ MPIAllgatherOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
+ MPIAllgatherOp<GPUDevice>);
+#endif
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI