path: root/tensorflow/contrib/mpi_collectives
author     Shanqing Cai <cais@google.com>                     2017-09-25 19:35:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-09-25 19:39:42 -0700
commit     e2e3a943c0a28b7656325acb3fcd035743d55ea0 (patch)
tree       f4b909d5410bdf3b94012392909e7805cd27a2a7 /tensorflow/contrib/mpi_collectives
parent     df22044be98c8b707601e03fe22ded53bcc28c7e (diff)
Merge changes from github.
END_PUBLIC

--- Commit 1e1b3d902 authored by Pete Warden<pete@petewarden.com>
Committed by gunan<gunan@google.com>:
Changed output directory for Pi CI build to fix permissions problem with nightlies (#13257)
* Fix for RTLD_GLOBAL breakage of Pi builds, and removed Eigen version change for Pi that's no longer needed
* Fixed Pi Zero OpenBLAS build problems and tidied up directories used
* More robust checks in Pi build script
* Changed output directory for Pi CI build to fix permissions problem

--- Commit fe3a2e65c authored by Yan Facai (???)<facai.yan@gmail.com>
Committed by drpngx<drpngx@users.noreply.github.com>:
check invalid string type for dest_nodes in extract_sub_graph (#13057)
* BUG: check str type
* TST: add unit test
* CLN: remove list check
* CLN: use warning
* CLN: 2 indent
* CLN: raise TypeError if not list
* CLN: check string only

--- Commit 225ab7629 authored by Jean Wanka<jm.wanka@gmail.com>
Committed by Jean Wanka<jm.wanka@gmail.com>:
Fix polynomial decay with cycle for global step=0
For polynomial decay with cycle=True the learning rate at step 0 becomes NaN, because in the process of calculating it we divide by 0. This change should fix it, by setting the multiplier for the decay steps to one for global_step=0.

--- Commit 286f57061 authored by Bjarke Hammersholt Roune<broune@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make Service::TransferToClient not attempt to manipulate the literal when the transfer failed, preventing a crash and allowing the caller to see the reason for the failed transfer.
PiperOrigin-RevId: 169770126

--- Commit e0501bc4d authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Shanqing Cai<cais@google.com>:
Fix GRUBlockCell parameter naming inconsistency (#13153)
* Fix GRUBlockCell parameter naming inconsistency
  This fix tries to fix the issue in 13137 where parameter `cell_size` is used instead of `num_units`. This is inconsistent with other RNN cells. This fix adds support of `num_units` while at the same time maintains backward compatibility for `cell_size`. This fix fixes 13137.
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add `@deprecated_args` for 'cell_size' in `GRUBlockCell`
  This commit adds `@deprecated_args` for 'cell_size' in `GRUBlockCell`
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Address review comment
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

--- Commit 02a2eba05 authored by Pete Warden<pete@petewarden.com>
Committed by gunan<gunan@google.com>:
Fix for RTLD_GLOBAL breakage of Pi builds, and removed Eigen version change that's no longer needed (#13251)
* Fix for RTLD_GLOBAL breakage of Pi builds, and removed Eigen version change for Pi that's no longer needed
* Fixed Pi Zero OpenBLAS build problems and tidied up directories used
* More robust checks in Pi build script

--- Commit 8ef722253 authored by Sanjoy Das<sanjoy@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Remove a redundant setName.
The EmitComputation should have emitted a function with the right name, so use a CHECK instead.
PiperOrigin-RevId: 169764856

--- Commit 1b94147dc authored by Neal Wu<wun@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix broken GitHub links in tensorflow and tensorflow_models resulting from The Great Models Move (a.k.a. the research subfolder)
PiperOrigin-RevId: 169763373

--- Commit b1ada5f0c authored by Justine Tunney<jart@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Fix TensorBoard python -m invoke in docs
PiperOrigin-RevId: 169758752

--- Commit 2957cd894 authored by Mustafa Ispir<ispir@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Local run option of estimator training.
PiperOrigin-RevId: 169756384

--- Commit 1dc2fe7ac authored by Gunhan Gulsoy<gunan@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 166264198

PiperOrigin-RevId: 169998124
Diffstat (limited to 'tensorflow/contrib/mpi_collectives')
-rw-r--r--  tensorflow/contrib/mpi_collectives/BUILD                    80
-rw-r--r--  tensorflow/contrib/mpi_collectives/README.md                 5
-rw-r--r--  tensorflow/contrib/mpi_collectives/__init__.py             273
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_allgather_test.py   114
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py   153
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_message.proto        64
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_ops.cc             1236
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_ops.py              165
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_ops_test.py         296
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.cc                  80
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.cu.cc              117
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.h                  327
12 files changed, 2910 insertions, 0 deletions
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
new file mode 100644
index 0000000000..11c5d6e776
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -0,0 +1,80 @@
+# Ops that communicate with other processes via MPI.
+
+package(default_visibility = [
+ "//tensorflow:__subpackages__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+)
+
+tf_proto_library_cc(
+ name = "mpi_message_proto",
+ srcs = ["mpi_message.proto"],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+ visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+tf_custom_op_library(
+ name = "mpi_collectives.so",
+ srcs = [
+ "mpi_ops.cc",
+ "ring.cc",
+ "ring.h",
+ ],
+ gpu_srcs = [
+ "ring.cu.cc",
+ "ring.h",
+ ],
+ deps = [
+ ":mpi_message_proto_cc",
+ "//third_party/mpi",
+ ],
+)
+
+tf_py_test(
+ name = "mpi_ops_test",
+ srcs = ["mpi_ops_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:platform",
+ ],
+ data = [
+ ":mpi_collectives.so",
+ ],
+ tags = ["manual"],
+)
+
+py_library(
+ name = "mpi_ops_py",
+ srcs = [
+ "__init__.py",
+ "mpi_ops.py",
+ ],
+ data = [
+ ":mpi_collectives.so",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/mpi_collectives/README.md b/tensorflow/contrib/mpi_collectives/README.md
new file mode 100644
index 0000000000..c5e1a8c37e
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/README.md
@@ -0,0 +1,5 @@
+# MPI TensorFlow integration
+
+TensorFlow MPI integration allows communication between different TensorFlow
+processes using MPI. This enables training across multiple nodes and GPUs
+using high-speed interconnects.
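+
+A minimal usage sketch (illustrative only; it assumes TensorFlow was built with
+MPI support and that the script is launched with an MPI launcher, for example
+`mpiexec -n 4 python my_script.py`, where `my_script.py` is your own script):
+
+```python
+import tensorflow as tf
+from tensorflow.contrib import mpi_collectives as mpi
+
+# Every process contributes a vector of ones; the ring allreduce sums them,
+# so the result should equal the number of MPI processes on every rank.
+ones = tf.ones([10])
+summed = mpi.allreduce(ones, average=False)
+
+with mpi.Session() as session:
+  my_rank, world_size = session.run([mpi.rank(), mpi.size()])
+  print("Process %d of %d" % (my_rank, world_size))
+  print("Allreduce result:", session.run(summed))
+```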
diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py
new file mode 100644
index 0000000000..b94f7b0a35
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/__init__.py
@@ -0,0 +1,273 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-short-docstring-punctuation
+"""## Communicating Between Processes with MPI
+
+TensorFlow natively provides inter-device communication through send and
+receive ops and inter-node communication through Distributed TensorFlow, based
+on the same send and receive abstractions. On HPC clusters where Infiniband or
+other high-speed node interconnects are available, these can end up being
+insufficient for synchronous data-parallel training (without asynchronous
+gradient descent). This module implements a variety of MPI ops which can take
+advantage of hardware-specific MPI libraries for efficient communication.
+
+In order to use this module, TensorFlow must be built with an MPI library,
+which can be provided to the `./configure` script at build time. As a user of
+TensorFlow, you will need to build TensorFlow yourself to select the MPI
+library to use; to do so, follow the [instructions for building TensorFlow from
+source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources).
+
+### Utility Ops
+
+In addition to reductions and gathers, this module provides utility operations
+for detecting the running MPI configuration.
+
+Example:
+
+```python
+from tensorflow.contrib import mpi
+
+# Use `mpi.Session` instead of `tf.Session`
+with mpi.Session() as session:
+ rank = session.run(mpi.rank())
+ print("My MPI Rank:", rank)
+
+ if rank == 0:
+ print("MPI Size:", session.run(mpi.size()))
+```
+
+@@rank
+@@size
+
+### Ring Allreduce and Allgather
+
+When summing or averaging tensors across many processes, communication can
+easily become a bottleneck. A naive implementation will send all the tensor
+values to the same process, perform the reduction, and then broadcast the
+values back to all other processes, effectively creating a synchronous
+parameter server in one process. However, the process responsible for
+performing the reduction will have to receive and send a massive amount of data
+which scales with the number of processes *and* the number of parameters in the
+model.
+
+Instead of centralizing the reduction and having one primary reducer, we can
+implement a distributed allreduce or allgather. A bandwidth-optimal allreduce
+will end up sending 2(N - 1) values for every value in the input tensor,
+and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce
+requires at least (N - 1) sends between the different nodes, and a broadcast of
+the result also requires (N - 1) sends, for a total of 2 (N - 1); these two
+steps cannot be combined in a clever way to reduce the number of required
+sends.) This module implements bandwidth-optimal ring allreduce and ring
+allgather operations using MPI; by choosing a hardware-appropriate MPI
+implementation (such as OpenMPI with CUDA-IPC support), you can train large
+models with synchronous gradient descent with minimal communication overhead.
+
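+For intuition, here is a small NumPy-only simulation of the two phases of a
+ring allreduce (scatter-reduce followed by allgather). It is an illustrative
+sketch of the algorithm only, not the MPI implementation used by this module:
+
+```python
+import numpy as np
+
+def simulate_ring_allreduce(buffers):
+  """Simulate a ring allreduce over `buffers`, one array per rank."""
+  n = len(buffers)
+  # Each rank splits its buffer into n segments.
+  segments = [np.array_split(np.asarray(b, dtype=np.float64), n)
+              for b in buffers]
+
+  # Scatter-reduce: in step t, rank r sends segment (r - t) % n to rank
+  # (r + 1) % n, which adds it to its own copy of that segment.
+  for t in range(n - 1):
+    for r in range(n):
+      seg, dst = (r - t) % n, (r + 1) % n
+      segments[dst][seg] = segments[dst][seg] + segments[r][seg]
+
+  # After n - 1 steps, rank r holds the fully reduced segment (r + 1) % n.
+  # Allgather: circulate the reduced segments once more around the ring.
+  for t in range(n - 1):
+    for r in range(n):
+      seg, dst = (r + 1 - t) % n, (r + 1) % n
+      segments[dst][seg] = segments[r][seg]
+
+  return [np.concatenate(s) for s in segments]
+
+ranks = [np.arange(8) * (r + 1) for r in range(4)]
+reduced = simulate_ring_allreduce(ranks)
+assert all(np.allclose(out, sum(ranks)) for out in reduced)
+```
+
+Each rank sends and receives (N - 1) segments in each phase, which is where
+the 2(N - 1) communication cost per value comes from.
+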
+In addition to the `allreduce` and `allgather` functions, a convenience
+`DistributedOptimizer` wrapper is provided to simplify using these functions
+for reducing model gradients.
+
+Example:
+
+```python
+import tensorflow as tf
+from tensorflow.contrib import mpi_collectives as mpi
+
+# Construct a simple linear regression model to optimize
+W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32)
+B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32)
+inputs = tf.placeholder(tf.float32, shape=[None, 20], name="Inputs")
+outputs = tf.placeholder(tf.float32, shape=[None, 1], name="Outputs")
+loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs)
+
+# Training using MPI allreduce with DistributedOptimizer
+optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer())
+train = optimizer.minimize(loss)
+
+# Average loss over all ranks, for printing.
+# Do not pass this to an optimizer!
+avg_loss = mpi.allreduce(loss)
+
+# On different ranks, feed different input data.
+with mpi.Session() as session:
+ rank = session.run(mpi.rank())
+ batch_inputs, batch_outputs = construct_batch_for_rank(rank)
+ feed_dict = {inputs: batch_inputs, outputs: batch_outputs}
+ _, l = session.run([train, avg_loss], feed_dict=feed_dict)
+ print("Average Loss:", l)
+```
+
+[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms
+for Clusters of Workstations".
+
+@@Session
+@@DistributedOptimizer
+@@allreduce
+@@allgather
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.mpi_collectives.mpi_ops import size
+from tensorflow.contrib.mpi_collectives.mpi_ops import rank
+from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
+from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
+from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
+from tensorflow.contrib.mpi_collectives.mpi_ops import init
+
+
+def allreduce(tensor, average=True):
+ """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices.
+
+ Arguments:
+ tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
+ The shape of the input must be identical across all ranks.
+ average: If True, computes the average over all ranks.
+ Otherwise, computes the sum over all ranks.
+
+ This function performs a bandwidth-optimal ring allreduce on the input
+ tensor. If the input is a tf.IndexedSlices, the function instead does an
+ allgather on the values and the indices, effectively doing an allreduce on
+ the represented tensor.
+ """
+ if isinstance(tensor, tf.IndexedSlices):
+ # For IndexedSlices, do two allgathers instead of an allreduce.
+ mpi_size = tf.cast(size(), tensor.values.dtype)
+ values = allgather(tensor.values)
+ indices = allgather(tensor.indices)
+
+ # To make this operation into an average, divide all gathered values by
+ # the MPI size.
+ new_values = tf.div(values, mpi_size) if average else values
+ return tf.IndexedSlices(new_values, indices,
+ dense_shape=tensor.dense_shape)
+ else:
+ mpi_size = tf.cast(size(), tensor.dtype)
+ summed_tensor = _allreduce(tensor)
+ new_tensor = (tf.div(summed_tensor, mpi_size)
+ if average else summed_tensor)
+ return new_tensor
+
+
+class DistributedOptimizer(tf.train.Optimizer):
+ """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to
+ average gradient values before applying gradients to model weights."""
+
+ def __init__(self, optimizer, name=None, use_locking=False):
+ """Construct a new DistributedOptimizer, which uses another optimizer
+ under the hood for computing single-process gradient values and
+ applying gradient updates after the gradient values have been averaged
+ across all the MPI ranks.
+
+ Args:
+ optimizer: Optimizer to use for computing gradients and applying updates.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Distributed" followed by the provided
+ optimizer type.
+ use_locking: Whether to use locking when updating variables. See
+ Optimizer.__init__ for more info.
+ """
+ if name is None:
+ name = "Distributed{}".format(type(optimizer).__name__)
+
+ self._optimizer = optimizer
+ super(DistributedOptimizer, self).__init__(
+ name=name, use_locking=use_locking)
+
+ def compute_gradients(self, *args, **kwargs):
+ """Compute gradients of all trainable variables.
+
+ See Optimizer.compute_gradients() for more info.
+
+ In DistributedOptimizer, compute_gradients() is overridden to also
+ allreduce the gradients before returning them.
+ """
+ gradients = (super(DistributedOptimizer, self)
+ .compute_gradients(*args, **kwargs))
+ return [(allreduce(gradient), var) for (gradient, var) in gradients]
+
+ def _apply_dense(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._apply_dense(*args, **kwargs)
+
+ def _apply_sparse(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._apply_sparse(*args, **kwargs)
+
+ def _apply_sparse_duplicate_indices(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._apply_sparse_duplicate_indices(*args,
+ **kwargs)
+
+ def _prepare(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._prepare(*args, **kwargs)
+
+ def _create_slots(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._create_slots(*args, **kwargs)
+
+ def _valid_dtypes(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._valid_dtypes(*args, **kwargs)
+
+ def _finish(self, *args, **kwargs):
+ """Calls this same method on the underlying optimizer."""
+ return self._optimizer._finish(*args, **kwargs)
+
+
+class Session(tf.Session):
+ """A class for running TensorFlow operations, with copies of the same graph
+ running distributed across different MPI nodes.
+
+ The primary difference between `tf.Session` and
+ `tf.contrib.mpi_collectives.Session` is that the MPI `Session` ensures that
+ the `Session` options are correct for use with `tf.contrib.mpi`, and
+ initializes MPI immediately upon the start of the session.
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Creates a new TensorFlow MPI session.
+
+ Unlike a normal `tf.Session`, an MPI Session may only use a single GPU,
+ which must be specified in advance before the session is initialized.
+ In addition, it only uses a single graph evaluation thread, and
+ initializes MPI immediately upon starting.
+
+ If no `graph` argument is specified when constructing the session,
+ the default graph will be launched in the session. If you are
+ using more than one graph (created with `tf.Graph()`) in the same
+ process, you will have to use different sessions for each graph,
+ but each graph can be used in multiple sessions. In this case, it
+ is often clearer to pass the graph to be launched explicitly to
+ the session constructor.
+
+ Args:
+ target: (Optional.) The execution engine to connect to.
+ graph: (Optional.) The `Graph` to be launched (described above).
+ config: (Optional.) A `ConfigProto` protocol buffer with configuration
+ options for the session.
+ """
+ super(Session, self).__init__(target, graph, config=config)
+
+ # Initialize MPI on the relevant device.
+ # TODO: Move this to library load and eliminate mpi.Session()
+ if graph is None:
+ graph = tf.get_default_graph()
+ with graph.as_default():
+ self.run(init())
diff --git a/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py b/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py
new file mode 100644
index 0000000000..c23dd33d57
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py
@@ -0,0 +1,114 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.mpi_collectives as mpi
+from tensorflow.python.platform import test
+
+
+average_allgather = False
+
+
+class AllgatherTest(test.TestCase):
+ def checkAllgather(self, num_ranks, all_gathered, local_gathered):
+ # Ensure that indices match.
+ all_gat_ind = np.sort(all_gathered.indices)
+ loc_gat_ind = np.sort(local_gathered.indices)
+ assert(len(loc_gat_ind) == len(all_gat_ind))
+ for i in range(len(loc_gat_ind)):
+ assert(loc_gat_ind[i] == all_gat_ind[i])
+
+ # For each index, verify same values.
+ local_checked = []
+ for i in range(len(local_gathered.indices)):
+ local_checked.append(False)
+ for i in range(len(all_gathered.indices)):
+ all_index = all_gathered.indices[i]
+ # TODO(jthestness): Make this lookup quicker using sorting.
+ loc_index = -1
+ for j in range(len(local_gathered.indices)):
+ if local_gathered.indices[j] == all_index and not local_checked[j]:
+ loc_index = j
+ local_checked[j] = True
+ break
+ assert(loc_index >= 0)
+ correct_output = local_gathered.values[loc_index][0]
+ if average_allgather:
+ correct_output = correct_output / float(num_ranks)
+ assert(all_gathered.values[i][0] == correct_output)
+
+
+ def test_mpi_allgather(self):
+ # Get MPI rank
+ my_rank = int(os.environ['PMI_RANK'])
+ num_ranks = int(os.environ['PMI_SIZE'])
+
+ indices_per_rank = 100
+ tensor_width = 10
+
+ # Create IndexedSlices for each rank, some with overlapping indices.
+ to_gather_indices = []
+ to_gather_values = []
+ to_gather = []
+ for rank_id in range(num_ranks):
+ indices = []
+ values = []
+ my_multiple = rank_id + 1
+ current_index = my_multiple
+ for i in range(indices_per_rank):
+ indices.append(current_index)
+ ones_tensor = tf.ones([tensor_width])
+ values.append(tf.multiply(ones_tensor,
+ tf.fill(ones_tensor.get_shape(),
+ float(current_index))))
+ current_index += my_multiple
+ concat_ind = tf.stack(indices)
+ concat_vals = tf.stack(values)
+ to_gather_indices.append(concat_ind)
+ to_gather_values.append(concat_vals)
+ to_gather.append(tf.IndexedSlices(concat_vals, concat_ind))
+
+ # Collect the local IndexedSlices (indices and values) to create
+ # correct IndexedSlices output.
+ correct_gather_indices = tf.concat(to_gather_indices, 0)
+ correct_gather_values = tf.concat(to_gather_values, 0)
+ correct_gather = tf.IndexedSlices(correct_gather_values,
+ correct_gather_indices)
+
+ all_gather = mpi.allreduce(to_gather[my_rank], average_allgather)
+
+ # NOTE: This assumes that device IDs are numbered the same as ranks.
+ gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
+ config = tf.ConfigProto(gpu_options=gpu_options)
+
+ # MPI Session to test allgather.
+ with mpi.Session(config=config) as sess:
+ sess.run(tf.global_variables_initializer())
+
+ all_gathered, local_gathered = sess.run([all_gather, correct_gather])
+
+ # Compare all_gathered with local_gathered.
+ self.checkAllgather(num_ranks, all_gathered, local_gathered)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
new file mode 100644
index 0000000000..001f9170bc
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
@@ -0,0 +1,153 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.mpi_collectives as mpi
+from tensorflow.python.platform import test
+
+
+average_allreduce = False
+max_wrong_count = -1
+
+
+class AllreduceTest(test.TestCase):
+ def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
+ our_correct):
+ # Find reduced/allreduced indices that are wrong and print all the
+ # values from output, slices, reduced, allreduced, so we can debug
+ # which is incorrect:
+ wrong_count = 0
+ red_dims = out_loc_red.shape
+ assert(len(red_dims) == 2)
+ for i in range(red_dims[0]):
+ for j in range(red_dims[1]):
+ suffix = ""
+ if out_loc_red[i][j] != my_correct[i][j] or \
+ out_all_red[i][j] != our_correct[i][j]:
+ suffix = "WRONG"
+ wrong_count += 1
+ print("{}\t{}\t{}\t{}\t{}\t{}"
+ .format(my_rank, i, j, out_loc_red[i][j],
+ out_all_red[i][j], suffix), flush=True)
+ if max_wrong_count > 0 and wrong_count >= max_wrong_count:
+ return
+
+ def test_mpi_allreduce(self):
+ # Get MPI rank
+ my_rank = int(os.environ['PMI_RANK'])
+ num_ranks = int(os.environ['PMI_SIZE'])
+
+ stages = 13
+ batch_size = 1331
+ hidden_size = batch_size
+ out_size = batch_size
+
+ # Input placeholder (batch_size x hidden) - init to 1s
+ inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
+ name="Input")
+
+ # Large matrices (hidden x out_dim) - init random
+ weights = []
+ for i in range(stages):
+ initer = tf.constant_initializer(pow(2.0, i + 1.0))
+ weights.append(tf.get_variable("weights_{}".format(i),
+ shape=(hidden_size, out_size),
+ dtype=tf.float32,
+ initializer=initer))
+
+ # Calculate output through dependent allreduces
+ stage_input = inputs
+ for i in range(stages):
+ inter_output = tf.add(stage_input, weights[i],
+ name="add_red_{}".format(i))
+ stage_input = mpi.allreduce(inter_output,
+ average=average_allreduce)
+
+ all_reduced = stage_input
+
+ # Local reduced output for verification
+ local_input = inputs
+ for i in range(stages):
+ inter_output = tf.add(local_input, weights[i],
+ name="addin_loc_{}".format(i))
+ my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
+ dtype=tf.float32, name="loc_redr_{}".format(i))
+ for r in range(num_ranks):
+ my_reducer = tf.add(my_reducer, inter_output,
+ name="add_loc_{}_{}".format(i, r))
+ if average_allreduce:
+ local_input = tf.div(my_reducer, num_ranks,
+ name="div_loc_{}".format(i))
+ else:
+ local_input = my_reducer
+
+ local_reduced = local_input
+
+ # NOTE: This assumes that device IDs are numbered the same as ranks
+ gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
+ config = tf.ConfigProto(gpu_options=gpu_options)
+
+ # MPI Session to test allreduce
+ with mpi.Session(config=config) as sess:
+ sess.run(tf.global_variables_initializer())
+
+ input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
+ our_output = input_feed[0][0]
+ spread_var = 100
+ input_feed = input_feed + my_rank * spread_var
+ my_output = input_feed[0][0]
+ for i in range(stages):
+ curr_feed = my_output + pow(2.0, i + 1.0)
+ my_output = curr_feed * num_ranks + 1
+ curr_our_feed = our_output + pow(2.0, i + 1.0)
+ if i == 0:
+ sum_ranks = num_ranks * (num_ranks - 1) / 2
+ our_output = curr_our_feed * num_ranks + \
+ spread_var * sum_ranks
+ else:
+ our_output = curr_our_feed * num_ranks
+
+ print("rank {}: My output is {}".format(my_rank, my_output))
+ my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
+ my_correct = my_correct + my_output
+ print("rank {}: Our output is {}".format(my_rank, our_output))
+ our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
+ our_correct = our_correct + our_output
+
+ for i in range(1000):
+ if i % 100 == 0:
+ print("{}: iter {}".format(my_rank, i), flush=True)
+ feed_dict = {inputs: input_feed}
+ out_all_red, out_loc_red \
+ = sess.run([all_reduced, local_reduced],
+ feed_dict=feed_dict)
+
+ if not np.allclose(out_loc_red, my_correct) or \
+ not np.allclose(out_all_red, our_correct):
+ print("Test incorrect on iter {}".format(i), flush=True)
+ self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
+ our_correct)
+ assert(np.allclose(out_loc_red, my_correct) and
+ np.allclose(out_all_red, our_correct))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto
new file mode 100644
index 0000000000..7fa5e20301
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_message.proto
@@ -0,0 +1,64 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow.contrib.mpi;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// An MPIRequest is a message sent from a rank greater than zero to the
+// coordinator (rank zero), informing the coordinator of an operation that
+// the rank wants to do and the tensor that it wants to apply the operation to.
+message MPIRequest {
+ enum RequestType {
+ ALLREDUCE = 0;
+ ALLGATHER = 1;
+ }
+
+ // The request rank is necessary to create a consistent ordering of results,
+ // for example in the allgather where the order of outputs should be sorted
+ // by rank.
+ int32 request_rank = 1;
+ RequestType request_type = 2;
+ DataType tensor_type = 3;
+ string tensor_name = 4;
+ TensorShapeProto tensor_shape = 5;
+};
+
+// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
+// greater than zero, informing the rank that an operation should be performed
+// now. If the operation requested would result in an error (for example, due
+// to a type or shape mismatch), then the MPIResponse can contain an error and
+// an error message instead. Finally, an MPIResponse can be a DONE message (if
+// there are no more tensors to reduce on this tick of the background loop) or
+// SHUTDOWN if all MPI processes should shut down.
+message MPIResponse {
+ enum ResponseType {
+ ALLREDUCE = 0;
+ ALLGATHER = 1;
+ ERROR = 2;
+ DONE = 3;
+ SHUTDOWN = 4;
+ }
+
+ // Empty if the type is DONE or SHUTDOWN.
+ ResponseType response_type = 1;
+ string tensor_name = 2;
+
+ // Empty unless response_type is ERROR.
+ string error_message = 3;
+};
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
new file mode 100644
index 0000000000..a051ab0004
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
@@ -0,0 +1,1236 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include <cuda_runtime.h>
+#include "tensorflow/stream_executor/stream.h"
+#endif
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+/*
+ * MPI Allreduce and Allgather Ops for TensorFlow.
+ *
+ * TensorFlow natively provides inter-device communication through send and
+ * receive ops and inter-node communication through Distributed TensorFlow,
+ * based on the same send and receive abstractions. These end up being
+ * insufficient for synchronous data-parallel training on HPC clusters where
+ * Infiniband or other high-speed interconnects are available. This module
+ * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
+ * gathers and reductions and can take advantage of hardware-optimized
+ * communication libraries through the MPI implementation.
+ *
+ * The primary logic of the allreduce and allgather are in RingAllgather() and
+ * RingAllreduce(). The background thread which facilitates MPI operations is
+ * run in BackgroundThreadLoop(). The provided MPI ops are:
+ * – MPIInit:
+ * Initialize MPI on a given device (CPU or GPU).
+ * Should only be run on a single device in every process.
+ * – MPISize:
+ * Get the number of MPI processes in the global communicator.
+ * – MPIRank:
+ * Get the rank of the current MPI process in the global communicator.
+ * – MPILocalRank:
+ * Get the local rank of the current MPI process within its node.
+ * – MPIAllreduce:
+ * Perform an allreduce on a Tensor, returning the sum
+ * across all MPI processes in the global communicator.
+ * – MPIAllgather:
+ * Perform an allgather on a Tensor, returning the concatenation of
+ * the tensor on the first dimension across all MPI processes in the
+ * global communicator.
+ *
+ */
+
+template <class T>
+using StatusOr = perftools::gputools::port::StatusOr<T>;
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+// Make sure template specializations are generated in the ring.cu.cc and the
+// ring.cc file, not in this file.
+extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<GPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<CPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+
+namespace {
+
+// Return true if the templated type is GPUDevice, otherwise false.
+template <typename T>
+bool IsGPUDevice();
+template <>
+bool IsGPUDevice<GPUDevice>() {
+ return true;
+};
+template <>
+bool IsGPUDevice<CPUDevice>() {
+ return false;
+};
+
+// A callback to call after the MPI communication completes. Since the
+// allreduce and allgather ops are asynchronous, this callback is what resumes
+// computation after the reduction is completed.
+typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
+
+struct CollectiveOpRecord {
+ // The rank performing this piece of the op
+ int rank;
+
+ // The name of the op/tensor to be reduced
+ std::string name;
+
+ // The op's kernel context
+ OpKernelContext* context;
+
+ // Data type of the op
+ DataType dtype;
+
+ // The input tensor
+ const Tensor* in_t;
+
+ // Allgather: Vector of per-rank first-dimension sizes
+ std::vector<size_t> sizes_vec;
+
+ // The temp tensor for intermediate results
+ Tensor temp_t;
+
+ // The output tensor
+ Tensor* out_t;
+
+ // Whether to run this op on the gpu
+ bool on_gpu;
+
+ // The callback to call after the op has completed
+ CommunicationDoneCallback callback;
+};
+
+// Table storing Tensors to be reduced, keyed by unique name.
+// This table contains everything necessary to do the reduction
+typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking and size calculations, as well as determining when a reduction is
+// ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct MPIGlobalState {
+ // An atomic boolean which is set to true when MPI is initialized.
+ // This ensures that MPI_Init is never called twice.
+ std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
+
+ // Condition variable to wait for initialization
+ condition_variable cv;
+
+ // Whether MPI_Init has been completed on the background thread.
+ bool initialization_done = false;
+
+ // Whether MPI_Init succeeded on the background thread.
+ Status init_status;
+
+ // A mutex that needs to be used whenever MPI operations touch
+ // shared structures.
+ mutex mu;
+
+ // Tensors waiting to be allreduced or allgathered.
+ TensorTable tensor_table;
+
+ // Queue of MPI requests waiting to be sent to the coordinator node.
+ std::queue<MPIRequest> message_queue;
+
+ // Background thread running MPI communication.
+ std::thread background_thread;
+
+ // Whether the background thread should shutdown.
+ bool shut_down = false;
+
+ // Only exists on the coordinator node (rank zero). Maintains a count of
+ // how many nodes are ready to allreduce every tensor (keyed by tensor
+ // name).
+ std::unique_ptr<MessageTable> message_table;
+
+ // The MPI rank, local rank, and size.
+ int rank = 0;
+ int local_rank = 0;
+ int size = 1;
+
+ // The device that MPI was initialized on. (-1 for no GPU)
+ int device = -1;
+
+ // The CUDA stream used for data transfers and within-allreduce operations.
+ // A naive implementation would use the TensorFlow StreamExecutor CUDA
+ // stream. However, the allreduce and allgather require doing memory copies
+ // and kernel executions (for accumulation of values on the GPU). However,
+ // the subsequent operations must wait for those operations to complete,
+ // otherwise MPI (which uses its own stream internally) will begin the data
+ // transfers before the CUDA calls are complete. In order to wait for those
+ // CUDA operations, if we were using the TensorFlow stream, we would have
+ // to synchronize that stream; however, other TensorFlow threads may be
+ // submitting more work to that stream, so synchronizing on it can cause
+ // the allreduce to be delayed, waiting for compute totally unrelated to it
+ // in other parts of the graph. Overlapping memory transfers and compute
+ // during backpropagation is crucial for good performance, so we cannot use
+ // the TensorFlow stream, and must use our own stream.
+#if GOOGLE_CUDA
+ cudaStream_t stream;
+ std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
+#endif
+
+ ~MPIGlobalState() {
+ // Make sure that the destructor of the background thread is safe to
+ // call. If a thread is still joinable (not detached or complete) its
+ // destructor cannot be called.
+ if (background_thread.joinable()) {
+ shut_down = true;
+ background_thread.join();
+ }
+ }
+};
+
+// All the MPI state that must be stored globally per-process.
+static MPIGlobalState mpi_global;
+
+// For clarity in argument lists.
+#define RANK_ZERO 0
+
+// A tag used for all coordinator messaging.
+#define TAG_NOTIFY 1
+
+// Store the MPIRequest for a name, and return whether the total count of
+// MPIRequests for that tensor is now equal to the MPI size (and thus we are
+// ready to reduce the tensor).
+bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
+ MPIRequest msg, int mpi_size) {
+ auto name = msg.tensor_name();
+ auto table_iter = message_table->find(name);
+ if (table_iter == message_table->end()) {
+ message_table->emplace(name, std::vector<MPIRequest>({msg}));
+ table_iter = message_table->find(name);
+ } else {
+ table_iter->second.push_back(msg);
+ }
+
+ int count = table_iter->second.size();
+ return count == mpi_size;
+}
+
+// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
+// to all ranks instructing them to start the reduction. The MPIResponse
+// also contains error messages in case the submitted MPIRequests were not
+// valid (for example, contained mismatched shapes or types).
+//
+// Constructing the MPIResponse, thus, requires a whole lot of error checking.
+MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
+ std::string name) {
+ bool error = false;
+ auto it = message_table->find(name);
+ assert(it != message_table->end());
+
+ std::vector<MPIRequest> requests = it->second;
+ assert(requests.size() > 0);
+
+ std::ostringstream error_message_stream;
+
+ // Check that all data types being reduced or gathered are identical
+ auto data_type = requests[0].tensor_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ auto request_type = requests[i].tensor_type();
+ if (data_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched data types: One rank had type "
+ << DataType_Name(data_type)
+ << ", but another rank had type "
+ << DataType_Name(request_type) << ".";
+ break;
+ }
+ }
+
+ // Check that all requested operations are the same
+ auto message_type = requests[0].request_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ auto request_type = requests[i].request_type();
+ if (message_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched MPI operations: One rank did an "
+ << message_type << ", but another rank did an "
+ << request_type << ".";
+ break;
+ }
+ }
+
+ // If we are doing an allreduce, check that all tensor shapes
+ // are identical
+ if (message_type == MPIRequest::ALLREDUCE) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape != request_shape) {
+ error = true;
+ error_message_stream << "Mismatched allreduce tensor shapes: "
+ << "One rank reduced a tensor of shape "
+ << tensor_shape.DebugString()
+ << ", but another rank sent a tensor of shape "
+ << request_shape.DebugString() << ".";
+ break;
+ }
+ }
+ }
+
+ // If we are doing an allgather, make sure all but the first dimension are
+ // the same. The first dimension may be different and the output tensor is
+ // the sum of the first dimension. Collect the sizes by rank.
+ if (message_type == MPIRequest::ALLGATHER) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+
+ if (tensor_shape.dims() == 0) {
+ error = true;
+ error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
+ }
+
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape.dims() != request_shape.dims()) {
+ error = true;
+ error_message_stream << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor of rank "
+ << tensor_shape.dims()
+ << ", but another rank sent a tensor of rank "
+ << request_shape.dims() << ".";
+ break;
+ }
+
+ for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
+ if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
+ error = true;
+ error_message_stream
+ << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor with dimension " << dim
+ << " equal to " << tensor_shape.dim_size(dim)
+ << ", but another rank sent a tensor with dimension " << dim
+ << " equal to " << request_shape.dim_size(dim) << ".";
+ break;
+ }
+ }
+ }
+ }
+
+ MPIResponse response;
+ response.set_tensor_name(name);
+ if (error) {
+ std::string error_message = error_message_stream.str();
+ response.set_response_type(MPIResponse::ERROR);
+ response.set_error_message(error_message);
+ } else {
+ auto response_type = MPIResponse::ERROR;
+ if (message_type == MPIRequest::ALLREDUCE) {
+ response_type = MPIResponse::ALLREDUCE;
+ } else {
+ response_type = MPIResponse::ALLGATHER;
+ }
+ response.set_response_type(response_type);
+ }
+
+ // Clear all queued up requests for this name. They are now taken care of
+ // by the constructed MPI response.
+ message_table->erase(it);
+
+ return response;
+}
+
+// Process an MPIResponse by doing a reduction, a gather, or raising an error.
+void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
+ OpKernelContext* context;
+ const Tensor* input_tensor;
+ std::vector<size_t> sizes_vec;
+ Tensor temp_tensor;
+ Tensor* output_tensor;
+ CommunicationDoneCallback callback;
+ bool on_gpu;
+ {
+ // Lock on the tensor table.
+ mutex_lock guard(mpi_global.mu);
+
+ // We should never fail at finding this key in the tensor table.
+ auto name = response.tensor_name();
+ auto iter = tensor_table.find(name);
+ assert(iter != tensor_table.end());
+
+ assert(response.response_type() == MPIResponse::ALLREDUCE ||
+ response.response_type() == MPIResponse::ALLGATHER ||
+ response.response_type() == MPIResponse::ERROR);
+
+ CollectiveOpRecord record = iter->second;
+ context = record.context;
+ input_tensor = record.in_t;
+ sizes_vec = record.sizes_vec;
+ temp_tensor = record.temp_t;
+ output_tensor = record.out_t;
+ on_gpu = record.on_gpu;
+ callback = record.callback;
+
+ // Clear the tensor table of this tensor and its callbacks; the rest of
+ // this function takes care of it.
+ tensor_table.erase(iter);
+ }
+
+ // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
+ // link to non-existent symbols.
+#if GOOGLE_CUDA
+#define GPU_DEVICE_IF_CUDA GPUDevice
+#else
+#define GPU_DEVICE_IF_CUDA CPUDevice
+#endif
+
+ Status status;
+ auto dtype = input_tensor->dtype();
+ if (response.response_type() == MPIResponse::ALLGATHER) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, float>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, int>(context, input_tensor,
+ sizes_vec, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, long long>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allgather.");
+ }
+ } else if (response.response_type() == MPIResponse::ALLREDUCE) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, float>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, int>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, long long>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allreduce.");
+ }
+ } else if (response.response_type() == MPIResponse::ERROR) {
+ status = errors::FailedPrecondition(response.error_message());
+ }
+
+ if (status.ok()) {
+ callback(StatusOr<Tensor>(*output_tensor));
+ } else {
+ callback(StatusOr<Tensor>(status));
+ }
+}
+
+// The MPI background thread loop coordinates all the MPI processes and the
+// tensor reductions. The design of the communicator mechanism is limited by a
+// few considerations:
+//
+// 1. Some MPI implementations require all MPI calls to happen from a
+// single thread. Since TensorFlow may use several threads for graph
+// processing, this means we must have our own dedicated thread for
+// dealing with MPI.
+// 2. We want to gracefully handle errors, when MPI processes do not
+// properly agree upon what should happen (such as mismatched types or
+// shapes). To do so requires the MPI processes to know about the shapes
+// and types of the relevant tensors on the other processes.
+// 3. The MPI reductions and gathers should be able to happen in parallel
+// with other ongoing operations. Since MPI uses an internal
+// (inaccessible) GPU stream separate from the TF GPUDevice streams, we
+// cannot explicitly synchronize memcpys or kernels with it. As a result,
+// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
+// ordering of memcpys and kernels with respect to TF streams.
+// 4. NOTE: We cannot guarantee that all the MPI processes reduce their
+// tensors in the same order. Thus, there must be a way to ensure the
+// reduction memcpys and kernels occur for correct tensors across all
+// ranks at the same time. We choose to use a coordinator (rank ID 0) to
+// gather and trigger the reduction operations that are ready to execute.
+//
+// The coordinator currently follows a master-worker paradigm. Rank zero acts
+// as the master (the "coordinator"), whereas all other ranks are simply
+// workers. Each rank runs its own background thread which progresses in ticks.
+// In each tick, the following actions happen:
+//
+// a) The workers send any available MPIRequests to the coordinator. These
+// MPIRequests indicate what the worker would like to do (i.e. which
+// tensor they would like to gather or reduce, as well as their shape and
+// type). They repeat this for every tensor that they would like to
+// operate on after that tensor's collective op has executed ComputeAsync.
+//
+// b) The workers send an empty "DONE" message to the coordinator to
+// indicate that there are no more tensors they wish to operate on.
+//
+// c) The coordinator receives the MPIRequests from the workers, as well
+// as from its own TensorFlow ops, and stores them in a request table. The
+// coordinator continues to receive MPIRequest messages until it has
+// received MPI_SIZE number of empty "DONE" messages.
+//
+// d) The coordinator finds all tensors that are ready to be reduced,
+// gathered, or all operations that result in an error. For each of those,
+// it sends an MPIResponse to all the workers. When no more MPIResponses
+// are available, it sends a "DONE" response to the workers. If the
+// process is being shutdown, it instead sends a "SHUTDOWN" response.
+//
+// e) The workers listen for MPIResponse messages, processing each one by
+// doing the required reduce or gather, until they receive a "DONE"
+// response from the coordinator. At that point, the tick ends.
+// If instead of "DONE" they receive "SHUTDOWN", they exit their
+// background loop.
+// TODO: Use the global mpi_global state variable instead of a local one
+void BackgroundThreadLoop() {
+#if GOOGLE_CUDA
+ // Set the device, so that this thread uses the same GPU context as the
+ // calling thread.
+ // TODO: Ensure that this is operating correctly. The background thread
+ // needs to be able to control all GPUs that the rank has access to, which
+ // might be more than one GPU. Tensors could be resident in any of the
+ // GPUs, so the background thread's accumulate and copy kernels might need
+ // to correctly set the device and it might be necessary for the background
+ // thread to manage multiple streams.
+ cudaSetDevice(mpi_global.device);
+ cudaStreamCreate(&mpi_global.stream);
+#endif
+
+ // Initialize MPI. This must happen on the background thread, since not all
+ // MPI implementations support being called from multiple threads.
+ auto init_result = MPI_Init(NULL, NULL);
+ if (init_result != MPI_SUCCESS) {
+ mpi_global.init_status =
+ errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
+ mpi_global.initialization_done = true;
+ mpi_global.cv.notify_all();
+ return;
+ } else {
+ mpi_global.init_status = Status::OK();
+ }
+
+ // Get MPI rank to determine if we are rank zero.
+ int rank;
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+ bool is_coordinator = rank == 0;
+
+ // Get MPI size to determine how many tensors to wait for before reducing.
+ int size;
+ MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+ // Determine local rank by querying the local communicator.
+ MPI_Comm local_comm;
+ MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
+ &local_comm);
+ int local_rank;
+ MPI_Comm_rank(local_comm, &local_rank);
+
+ mpi_global.rank = rank;
+ mpi_global.local_rank = local_rank;
+ mpi_global.size = size;
+ mpi_global.initialization_done = true;
+
+ // Notify calling thread that initialization is complete
+ mpi_global.cv.notify_all();
+
+ // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
+ // Initialize the tensor count table. No tensors are available yet.
+ if (is_coordinator) {
+ mpi_global.message_table =
+ std::unique_ptr<MessageTable>(new MessageTable());
+ }
+
+ // The coordinator sends a SHUTDOWN message to trigger shutdown.
+ bool should_shut_down = false;
+ do {
+ // TODO: Eliminate the need for thread sleep by making all activity
+ // depend on other activity (e.g. condition or MPI waits).
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+ // Copy the data structures from global state under this lock.
+ // However, don't keep the lock for the rest of the loop, so that
+ // enqueued stream callbacks can continue.
+ std::queue<MPIRequest> message_queue;
+ {
+ mutex_lock guard(mpi_global.mu);
+ while (!mpi_global.message_queue.empty()) {
+ MPIRequest message = mpi_global.message_queue.front();
+ mpi_global.message_queue.pop();
+ message_queue.push(message);
+ }
+ }
+
+ // Collect all tensors that are ready to be reduced. Record them in the
+ // tensor count table (rank zero) or send them to rank zero to be
+ // recorded (everyone else).
+ std::vector<std::string> ready_to_reduce;
+ while (!message_queue.empty()) {
+ // Pop the first available message
+ MPIRequest message = message_queue.front();
+ message_queue.pop();
+
+ if (is_coordinator) {
+ bool reduce =
+ IncrementTensorCount(mpi_global.message_table, message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(message.tensor_name());
+ }
+ } else {
+ std::string encoded_message;
+ message.SerializeToString(&encoded_message);
+ MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
+ MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ }
+
+ // Rank zero has put all its own tensors in the tensor count table.
+ // Now, it should count all the tensors that are coming from other
+ // ranks at this tick. It should keep getting tensors until it gets a
+ // DONE message from all the other ranks.
+ if (is_coordinator) {
+ // Count of DONE messages. Keep receiving messages until the number
+ // of messages is equal to the number of processes. Initialize to
+ // one since the coordinator is effectively done.
+ int completed_ranks = 1;
+ while (completed_ranks != size) {
+ MPI_Status status;
+ MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int source_rank = status.MPI_SOURCE;
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // If the length is zero, this is a DONE message.
+ if (msg_length == 0) {
+ completed_ranks++;
+ MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ continue;
+ }
+
+ // Read the serialized MPIRequest from MPI into an std::string. Construct
+ // the string with an explicit length (the count includes the trailing zero
+ // byte) so embedded zero bytes in the serialized proto are preserved.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
+ MPI_COMM_WORLD, &status);
+ std::string received_data(buffer, msg_length - 1);
+ delete[] buffer;
+
+ MPIRequest received_message;
+ received_message.ParseFromString(received_data);
+ auto received_name = received_message.tensor_name();
+
+ bool reduce = IncrementTensorCount(mpi_global.message_table,
+ received_message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(received_name);
+ }
+ }
+
+ // At this point, rank zero should have a fully updated tensor
+ // count table and should know all the tensors that need to be
+ // reduced or gathered, and everyone else should have sent all
+ // their information to rank zero. We can now do reductions and
+ // gathers; rank zero will choose which ones and in what order,
+ // and will notify the other ranks before doing each reduction.
+ for (int i = 0; i < ready_to_reduce.size(); i++) {
+ // Notify all nodes which tensor we'd like to reduce now
+ auto name = ready_to_reduce[i];
+ MPIResponse response =
+ ConstructMPIResponse(mpi_global.message_table, name);
+
+ std::string encoded_response;
+ response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+
+ // Perform the reduction. All nodes should end up performing
+ // the same reduction.
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+
+ // Notify all nodes that we are done with the reductions for this
+ // tick.
+ MPIResponse done_response;
+ should_shut_down = mpi_global.shut_down;
+ done_response.set_response_type(
+ mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
+ std::string encoded_response;
+ done_response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ } else {
+ // Notify the coordinator that this node is done sending messages.
+ // A DONE message is encoded as a zero-length message.
+ MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+ // Receive names for tensors to reduce from rank zero. Once we
+ // receive a DONE response, stop waiting for more names.
+ while (true) {
+ MPI_Status status;
+ MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // Read the serialized MPIResponse from MPI into an std::string, again with
+ // an explicit length so embedded zero bytes survive.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ std::string received_message(buffer, msg_length - 1);
+ delete[] buffer;
+
+ MPIResponse response;
+ response.ParseFromString(received_message);
+ if (response.response_type() == MPIResponse::DONE) {
+ // No more messages this tick
+ break;
+ } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+ // No more messages this tick, and the background thread
+ // should shut down
+ should_shut_down = true;
+ break;
+ } else {
+ // Process the current message
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+ }
+ }
+ } while (!should_shut_down);
+
+ MPI_Finalize();
+}
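To make the coordination above concrete: a tensor becomes ready for reduction only once IncrementTensorCount has seen a request for it from every rank. A minimal pure-Python sketch of that counting rule (hypothetical names, no MPI; the real code exchanges serialized MPIRequest protos over TAG_NOTIFY):

class MessageTableSketch(object):
    """Toy stand-in for the coordinator's tensor count table."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.pending = {}  # tensor name -> list of per-rank requests

    def increment(self, tensor_name, request):
        """Record one rank's request; True means all ranks have reported."""
        requests = self.pending.setdefault(tensor_name, [])
        requests.append(request)
        return len(requests) == self.world_size


# With four ranks, "grad_0" is ready to reduce only after the fourth report.
table = MessageTableSketch(world_size=4)
ready = [table.increment("grad_0", rank) for rank in range(4)]
assert ready == [False, False, False, True]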
+
+// Initialize MPI and start the MPI background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+Status InitializeMPIOnce(bool gpu) {
+ // Ensure MPI is only initialized once.
+ if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
+
+ mpi_global.device = -1;
+#if GOOGLE_CUDA
+ if (gpu) {
+ cudaGetDevice(&mpi_global.device);
+ }
+#endif
+
+ // Start the MPI background thread, which performs MPI initialization and
+ // handles all MPI communication.
+ // TODO: Change this to a TensorFlow thread
+ mpi_global.background_thread = std::thread(BackgroundThreadLoop);
+
+ // Wait to ensure that the background thread has finished initializing MPI
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.cv.wait(guard);
+ if (!mpi_global.initialization_done) {
+ mpi_global.init_status =
+ errors::Unknown("Failed to wait for MPI initialization.");
+ }
+
+ return mpi_global.init_status;
+}
+
+// Check that MPI is initialized.
+Status IsMPIInitialized() {
+ if (!mpi_global.initialization_done) {
+ return errors::FailedPrecondition(
+ "MPI has not been initialized; use tf.contrib.mpi.Session.");
+ }
+ return Status::OK();
+}
+
+// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
+// only adds the op's record into the local tensor table (to track the op's
+// progress) and enqueues a request announcing that this rank is ready to
+// begin. The MPI background thread later sends the request to the coordinator.
+void EnqueueTensorCollective(CollectiveOpRecord record,
+ MPIRequest::RequestType rtype) {
+ const Tensor* input_tensor = record.in_t;
+ MPIRequest message;
+ message.set_request_rank(record.rank);
+ message.set_tensor_name(record.name);
+ message.set_tensor_type(record.dtype);
+ message.set_request_type(rtype);
+ input_tensor->shape().AsProto(message.mutable_tensor_shape());
+
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.tensor_table.emplace(record.name, record);
+ mpi_global.message_queue.push(message);
+}
+
+} // namespace
+
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
+#endif
+
+// Op to initialize MPI in the current process. The device this op runs on
+// (CPU or GPU) is the device that all future MPI ops must run on.
+template <typename Device>
+class MPIInitOp : public OpKernel {
+ public:
+ explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool on_gpu = IsGPUDevice<Device>();
+ OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
+ MPIInitOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
+ MPIInitOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIInit").Doc(R"doc(
+Initialize MPI for the current process.
+
+If this is run on a GPU, then that GPU must be used for all future MPI
+operations. If it is run on CPU, then all future MPI operations must also
+run on CPU.
+)doc");
+
+// Op to get the current MPI Size.
+template <typename Device>
+class MPISizeOp : public OpKernel {
+ public:
+ explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.size;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
+ MPISizeOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
+ MPISizeOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPISize")
+ .Output("size: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the number of running MPI processes.
+
+More precisely, returns the number of MPI processes in the group associated
+with the MPI_COMM_WORLD communicator.
+
+size: Size of the MPI group.
+)doc");
+
+// Op to get the current MPI Rank.
+template <typename Device>
+class MPIRankOp : public OpKernel {
+ public:
+ explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
+ MPIRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPIRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the MPI group.
+
+More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
+communicator.
+
+rank: Rank of the calling process.
+)doc");
+
+// Op to get the current local MPI Rank.
+template <typename Device>
+class MPILocalRankOp : public OpKernel {
+ public:
+ explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.local_rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
+ MPILocalRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPILocalRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPILocalRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the node it is on.
+
+More precisely, returns the rank of the calling process in a communicator that
+only spans the MPI processes running on the same node.
+
+rank: Rank of the calling process on the node it is on.
+)doc");
+
+template <typename Device>
+class MPIAllreduceOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllreduceOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, input_tensor->shape(), &output_tensor),
+ done);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.out_t = output_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+ record.dtype = input_tensor->dtype();
+
+ const size_t temp_size =
+ (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
+ TensorShape temp_shape;
+ temp_shape.AddDim(temp_size);
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_temp(input_tensor->dtype(),
+ temp_shape, &record.temp_t),
+ done);
+
+ auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allreduce_done_callback;
+
+ auto allreduce_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allreduce; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allreduce_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allreduce_launch_callback);
+ }
+#else
+ allreduce_launch_callback();
+#endif
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
+ MPIAllreduceOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
+ MPIAllreduceOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIAllreduce")
+ .Attr("T: {int32, int64, float32}")
+ .Input("tensor: T")
+ .Output("sum: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allreduce on a tensor. Tensors are reduced across processes by
+node name, so every process that reduces a tensor under a given name must use
+the same shape for that tensor.
+
+Arguments
+ tensor: A tensor to reduce.
+
+Output
+ sum: A tensor with the same shape as `tensor`, summed across all
+ MPI processes.
+)doc");
+
+template <typename Device>
+class MPIAllgatherOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllgatherOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ const Tensor* sizing_tensor = &context->input(1);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+
+ // Construct the output size from the sizing tensor
+ size_t output_first_dim = 0;
+ if (sizing_tensor->shape().dims() == 0) {
+ // 0-dim sizing_tensor implies that the op is just gathering
+ // a single element from each rank
+ output_first_dim = mpi_global.size;
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(1);
+ }
+ } else {
+ // Collect the total output tensor sizing from the sizing tensor
+ // NOTE: The sizing tensor is forced to be placed on the CPU by
+ // declaring the input as HostMemory, so it is valid to read it here.
+ const int64* sizing_array =
+ (const int64*)sizing_tensor->tensor_data().data();
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(sizing_array[i]);
+ output_first_dim += sizing_array[i];
+ }
+ }
+
+ TensorShape output_shape;
+ output_shape.AddDim(output_first_dim);
+ for (int i = 1; i < input_tensor->shape().dims(); i++) {
+ output_shape.AddDim(input_tensor->shape().dim_size(i));
+ }
+
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output_tensor),
+ done);
+
+ record.out_t = output_tensor;
+ record.dtype = input_tensor->dtype();
+
+ auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allgather_done_callback;
+
+ auto allgather_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allgather; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allgather_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allgather_launch_callback);
+ }
+#else
+ allgather_launch_callback();
+#endif
+ }
+};
+
+REGISTER_OP("MPIAllgather")
+ .Attr("T: {int32, int64, float32}")
+ .Attr("S: {int64}")
+ .Input("tensor: T")
+ .Input("sizes: S")
+ .Output("gathered: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle output;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+ c->set_output(0, output);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allgather on a tensor. Every process that gathers a tensor under
+the same name must use a tensor of the same rank, with identical sizes in every
+dimension except the first.
+
+Arguments
+ tensor: A tensor to gather.
+ sizes: A tensor containing the first-dimension sizes of tensors to be
+ gathered from other ranks
+
+Output
+ gathered: A tensor with the same shape as `tensor` except for the first
+ dimension, which is the sum of dimensions in `sizes`.
+)doc");
+
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
+ MPIAllgatherOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
+ MPIAllgatherOp<GPUDevice>);
+#endif
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py
new file mode 100644
index 0000000000..81567cc688
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py
@@ -0,0 +1,165 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Inter-process communication using MPI."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _load_library(name, op_list=None):
+ """Loads a .so file containing the specified operators.
+
+ Args:
+ name: The name of the .so file to load.
+ op_list: A list of names of operators that the library should have. If None
+ then the .so file's contents will not be verified.
+
+ Raises:
+ NameError if one of the required ops is missing.
+ """
+ try:
+ filename = resource_loader.get_path_to_datafile(name)
+ library = load_library.load_op_library(filename)
+ for expected_op in (op_list or []):
+ for lib_op in library.OP_LIST.op:
+ if lib_op.name == expected_op:
+ break
+ else:
+ raise NameError(
+ 'Could not find operator %s in dynamic library %s' %
+ (expected_op, name))
+ return library
+ except errors.NotFoundError:
+ logging.warning('%s file could not be loaded.', name)
+
+
+MPI_LIB = _load_library('mpi_collectives.so', ['MPISize', 'MPIRank',
+ 'MPILocalRank', 'MPIAllgather',
+ 'MPIAllreduce'])
+
+
+def size(name=None):
+ """An op which returns the number of MPI processes.
+
+ This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
+ size of the global communicator.
+
+ Returns:
+ An integer scalar containing the number of MPI processes.
+ """
+ return MPI_LIB.mpi_size(name=name)
+
+
+ops.NotDifferentiable('MPISize')
+
+
+def rank(name=None):
+ """An op which returns the MPI rank of the calling process.
+
+ This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
+ rank of the current process in the global communicator.
+
+ Returns:
+ An integer scalar with the MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_rank(name=name)
+
+
+ops.NotDifferentiable('MPIRank')
+
+
+def init(name=None):
+ """An op which initializes MPI on the device on which it is run.
+
+ All future MPI ops must be run on the same device that the `init` op was run
+ on.
+ """
+ return MPI_LIB.mpi_init(name=name)
+
+
+ops.NotDifferentiable('MPIInit')
+
+
+def local_rank(name=None):
+ """An op which returns the local MPI rank of the calling process, within the
+ node that it is running on. For example, if there are seven processes running
+ on a node, their local ranks will be zero through six, inclusive.
+
+ This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
+ which only includes processes on the same node.
+
+ Returns:
+ An integer scalar with the local MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_local_rank(name=name)
+
+
+ops.NotDifferentiable('MPILocalRank')
+
+
+def _allreduce(tensor, name=None):
+ """An op which sums an input tensor over all the MPI processes.
+
+ The reduction operation is keyed by the name of the op. The tensor type and
+ shape must be the same on all MPI processes for a given name. The reduction
+ will not start until all processes are ready to send and receive the tensor.
+
+ Returns:
+ A tensor of the same shape and type as `tensor`, summed across all
+ processes.
+ """
+ return MPI_LIB.mpi_allreduce(tensor, name=name)
+
+
+ops.NotDifferentiable('MPIAllreduce')
+
+
+def allgather(tensor, name=None):
+ """An op which concatenates the input tensor with the same input tensor on
+ all other MPI processes.
+
+ The concatenation is done on the first dimension, so the input tensors on the
+ different processes must have the same rank and shape, except for the first
+ dimension, which is allowed to be different.
+
+ Returns:
+ A tensor of the same type as `tensor`, concatenated on dimension zero
+ across all processes. The shape is identical to the input shape, except for
+ the first dimension, which may be greater and is the sum of all first
+ dimensions of the tensors in different MPI processes.
+ """
+ # Specify that first allgather is to collect the tensor gather sizes,
+ # indicated by passing in a scalar (0-D tensor) of value 0
+ sizes_flag = tf.constant(0, dtype=tf.int64, name="size_flag_const")
+ my_size = tf.slice(tf.shape(tensor, out_type=tf.int64), [0], [1], name="size_slice")
+ if name is None:
+ name = "allgather"
+ sizing_name = "{}_sizing".format(name)
+ sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name)
+ return MPI_LIB.mpi_allgather(tensor, sizes, name=name)
+
+
+ops.NotDifferentiable('MPIAllgather')
+
+
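A hedged usage sketch of the wrappers above, assuming the package `__init__` re-exports `init`, `rank`, `size` and an `allreduce(tensor, average=...)` helper (as the tests in the next file do), and that the script is launched under something like `mpirun -np 4 python train.py`:

import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi

with tf.Session() as session:
    session.run(mpi.init())  # must run before any other MPI op
    rank, size = session.run([mpi.rank(), mpi.size()])
    gradient = tf.random_uniform([10], -1.0, 1.0)
    summed = session.run(mpi.allreduce(gradient, average=False))
    print("rank %d of %d reduced %d elements" % (rank, size, len(summed)))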
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops_test.py b/tensorflow/contrib/mpi_collectives/mpi_ops_test.py
new file mode 100644
index 0000000000..48e5c0a0c7
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops_test.py
@@ -0,0 +1,296 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for tensorflow.contrib.mpi_collectives.mpi_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import itertools
+
+import tensorflow as tf
+
+import tensorflow.contrib.mpi_collectives as mpi
+
+
+def mpi_env_rank_and_size():
+ """Get MPI rank and size from environment variables and return them as a
+ tuple of integers.
+
+ Most MPI implementations have an `mpirun` or `mpiexec` command that will
+ run an MPI executable and set up all communication necessary between the
+ different processors. As part of that set up, they will set environment
+ variables that contain the rank and size of the MPI_COMM_WORLD
+ communicator. We can read those environment variables from Python in order
+ to ensure that `mpi.rank()` and `mpi.size()` return the expected values.
+
+ Since MPI is just a standard, not an implementation, implementations
+ typically choose their own environment variable names. This function tries
+ to support several different implementations, but really it only needs to
+ support whatever implementation we want to use for the TensorFlow test
+ suite.
+
+ If this is not running under MPI, then defaults of rank zero and size one
+ are returned. (This is appropriate because when you call MPI_Init in an
+ application not started with mpirun, it will create a new independent
+ communicator with only one process in it.)
+ """
+ rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split()
+ size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split()
+
+ for rank_var, size_var in zip(rank_env, size_env):
+ rank = os.environ.get(rank_var)
+ size = os.environ.get(size_var)
+ if rank is not None and size is not None:
+ return int(rank), int(size)
+
+ # Default to rank zero and size one if there are no environment variables
+ return 0, 1
+
+
+class MPITests(tf.test.TestCase):
+ """
+ Tests for MPI ops in tensorflow.contrib.mpi_collectives.
+ """
+
+ def test_mpi_rank(self):
+ """Test that the rank returned by mpi.rank() is correct."""
+ true_rank, _ = mpi_env_rank_and_size()
+ with self.test_session() as session:
+ rank = session.run(mpi.rank())
+ self.assertEqual(true_rank, rank)
+
+ def test_mpi_size(self):
+ """Test that the size returned by mpi.size() is correct."""
+ _, true_size = mpi_env_rank_and_size()
+ with self.test_session() as session:
+ size = session.run(mpi.size())
+ self.assertEqual(true_size, size)
+
+ def test_mpi_allreduce_cpu(self):
+ """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors."""
+ with self.test_session() as session:
+ size = session.run(mpi.size())
+
+ dtypes = [tf.int32, tf.float32]
+ dims = [1, 2, 3]
+ for dtype, dim in itertools.product(dtypes, dims):
+ tf.set_random_seed(1234)
+ tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
+ summed = mpi.allreduce(tensor, average=False)
+ multiplied = tensor * size
+ max_difference = tf.reduce_max(tf.abs(summed - multiplied))
+
+ # Threshold for floating point equality depends on number of
+ # ranks, since we're comparing against precise multiplication.
+ if size <= 3:
+ threshold = 0
+ elif size < 10:
+ threshold = 1e-4
+ elif size < 15:
+ threshold = 5e-4
+ else:
+ break
+
+ diff = session.run(max_difference)
+ self.assertTrue(diff <= threshold,
+ "mpi.allreduce produces incorrect results")
+
+ def test_mpi_allreduce_gpu(self):
+ """Test that the allreduce works on GPUs.
+
+ This test will crash badly if used with an MPI implementation that does
+ not support GPU memory transfers directly, as it will call MPI_Send on
+ a GPU data pointer."""
+ # Only do this test if there are GPUs available.
+ if not tf.test.is_gpu_available(cuda_only=True):
+ return
+
+ no_gpus = tf.GPUOptions(visible_device_list="")
+ cpu_config = tf.ConfigProto(gpu_options=no_gpus)
+ with self.test_session(config=cpu_config) as session:
+ local_rank = session.run(mpi.local_rank())
+
+ one_gpu = tf.GPUOptions(visible_device_list=str(local_rank))
+ gpu_config = tf.ConfigProto(gpu_options=one_gpu)
+ with self.test_session(config=gpu_config) as session:
+ size = session.run(mpi.size())
+
+ dtype = tf.float32
+ dim = 3
+ with tf.device("/gpu:0"):
+ tf.set_random_seed(1234)
+ tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
+ summed = mpi.allreduce(tensor, average=False)
+ multiplied = tensor * size
+ max_difference = tf.reduce_max(tf.abs(summed - multiplied))
+
+ # Threshold for floating point equality depends on number of
+ # ranks, since we're comparing against precise multiplication.
+ if size <= 3:
+ threshold = 0
+ elif size < 10:
+ threshold = 1e-4
+ elif size < 15:
+ threshold = 5e-4
+ else:
+ return
+
+ diff = session.run(max_difference)
+ self.assertTrue(diff <= threshold,
+ "mpi.allreduce on GPU produces incorrect results")
+
+ def test_mpi_allreduce_error(self):
+ """Test that the allreduce raises an error if different ranks try to
+ send tensors of different rank or dimension."""
+ with self.test_session() as session:
+ rank = session.run(mpi.rank())
+ size = session.run(mpi.size())
+
+ # This test does not apply if there is only one worker.
+ if size == 1:
+ return
+
+ # Same rank, different dimension
+ tf.set_random_seed(1234)
+ dims = [17 + rank] * 3
+ tensor = tf.random_uniform(dims, -1.0, 1.0)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ session.run(mpi.allreduce(tensor))
+
+ # Same number of elements, different rank
+ tf.set_random_seed(1234)
+ if rank == 0:
+ dims = [17, 23 * 57]
+ else:
+ dims = [17, 23, 57]
+ tensor = tf.random_uniform(dims, -1.0, 1.0)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ session.run(mpi.allreduce(tensor))
+
+ def test_mpi_allreduce_type_error(self):
+ """Test that the allreduce raises an error if different ranks try to
+ send tensors of different type."""
+ with self.test_session() as session:
+ rank = session.run(mpi.rank())
+ size = session.run(mpi.size())
+
+ # This test does not apply if there is only one worker.
+ if size == 1:
+ return
+
+ # Same shape, different data type
+ dims = [17] * 3
+ tensor = tf.ones(dims, dtype=tf.int32 if rank % 2 == 0 else tf.float32)
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ session.run(mpi.allreduce(tensor))
+
+ def test_mpi_allgather(self):
+ """Test that the allgather correctly gathers 1D, 2D, 3D tensors."""
+ with self.test_session() as session:
+ size = session.run(mpi.size())
+ rank = session.run(mpi.rank())
+
+ dtypes = tf.int32, tf.float32
+ dims = 1, 2, 3
+ for dtype, dim in itertools.product(dtypes, dims):
+ tensor = tf.ones([17] * dim, dtype=dtype) * rank
+ gathered = mpi.allgather(tensor)
+
+ gathered_tensor = session.run(gathered)
+ self.assertEqual(list(gathered_tensor.shape),
+ [17 * size] + [17] * (dim - 1))
+
+ for i in range(size):
+ rank_tensor = tf.slice(gathered_tensor, [i * 17] + [0] * (dim - 1),
+ [17] + [-1] * (dim - 1))
+ self.assertEqual(list(rank_tensor.shape), [17] * dim)
+ self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
+ "mpi.allgather produces incorrect gathered tensor")
+
+ def test_mpi_allgather_variable_size(self):
+ """Test that the allgather correctly gathers 1D, 2D, 3D tensors,
+ even if those tensors have different sizes along the first dim."""
+ with self.test_session() as session:
+ size = session.run(mpi.size())
+ rank = session.run(mpi.rank())
+
+ dtypes = tf.int32, tf.float32
+ dims = 1, 2, 3
+ for dtype, dim in itertools.product(dtypes, dims):
+ # Support tests up to MPI Size of 35
+ if size > 35:
+ break
+
+ tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5
+ tensor_sizes = tensor_sizes[:size]
+
+ tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1),
+ dtype=dtype) * rank
+ gathered = mpi.allgather(tensor)
+
+ gathered_tensor = session.run(gathered)
+ expected_size = sum(tensor_sizes)
+ self.assertEqual(list(gathered_tensor.shape),
+ [expected_size] + [17] * (dim - 1))
+
+ for i in range(size):
+ rank_size = [tensor_sizes[i]] + [17] * (dim - 1)
+ rank_tensor = tf.slice(gathered,
+ [sum(tensor_sizes[:i])] + [0] * (dim - 1),
+ rank_size)
+ self.assertEqual(list(rank_tensor.shape), rank_size)
+ self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
+ "mpi.allgather produces incorrect gathered tensor")
+
+ def test_mpi_allgather_error(self):
+ """Test that the allgather returns an error if any dimension besides
+ the first is different among the tensors being gathered."""
+ with self.test_session() as session:
+ rank = session.run(mpi.rank())
+ size = session.run(mpi.size())
+
+ # This test does not apply if there is only one worker.
+ if size == 1:
+ return
+
+ tensor_size = [17] * 3
+ tensor_size[1] = 10 * (rank + 1)
+ tensor = tf.ones(tensor_size, dtype=tf.float32) * rank
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ session.run(mpi.allgather(tensor))
+
+ def test_mpi_allgather_type_error(self):
+ """Test that the allgather returns an error if the types being gathered
+ differ among the processes"""
+ with self.test_session() as session:
+ rank = session.run(mpi.rank())
+ size = session.run(mpi.size())
+
+ # This test does not apply if there is only one worker.
+ if size == 1:
+ return
+
+ tensor_size = [17] * 3
+ dtype = tf.int32 if rank % 2 == 0 else tf.float32
+ tensor = tf.ones(tensor_size, dtype=dtype) * rank
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ session.run(mpi.allgather(tensor))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc
new file mode 100644
index 0000000000..d93233eb21
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+extern template MPI_Datatype MPIType<float>();
+extern template MPI_Datatype MPIType<int>();
+extern template MPI_Datatype MPIType<long long>();
+extern template DataType TensorFlowDataType<float>();
+extern template DataType TensorFlowDataType<int>();
+extern template DataType TensorFlowDataType<long long>();
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Copy data on a CPU using a straightforward memcpy.
+template <>
+void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
+ std::memcpy(dst, src, size);
+};
+
+// Accumulate values on a CPU.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<CPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ for (unsigned int i = 0; i < size; i++) { \
+ dst[i] += src[i]; \
+ } \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc
new file mode 100644
index 0000000000..2f3eef366a
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <>
+MPI_Datatype MPIType<float>() {
+ return MPI_FLOAT;
+};
+template <>
+MPI_Datatype MPIType<int>() {
+ return MPI_INT;
+};
+template <>
+MPI_Datatype MPIType<long long>() {
+ return MPI_LONG_LONG;
+};
+
+template <>
+DataType TensorFlowDataType<float>() {
+ return DT_FLOAT;
+};
+template <>
+DataType TensorFlowDataType<int>() {
+ return DT_INT32;
+};
+template <>
+DataType TensorFlowDataType<long long>() {
+ return DT_INT64;
+};
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Synchronously copy data on the GPU, using a stream distinct from both the
+// default stream and TensorFlow's streams, to avoid synchronizing on
+// operations unrelated to the allreduce.
+template <>
+void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+ auto stream = CudaStreamForMPI();
+ cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+ cudaStreamSynchronize(stream);
+};
+
+// Elementwise accumulation kernel for GPU.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] += in[i];
+ }
+}
+
+// Synchronously accumulate tensors on the GPU, using a stream distinct from
+// both the default stream and TensorFlow's streams, to avoid synchronizing on
+// operations unrelated to the allreduce.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ auto stream = CudaStreamForMPI(); \
+ elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size); \
+ cudaStreamSynchronize(stream); \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h
new file mode 100644
index 0000000000..cae57ce60e
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.h
@@ -0,0 +1,327 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_MPI_H_
+#define TENSORFLOW_CONTRIB_MPI_H_
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#if GOOGLE_CUDA
+#include "cuda_runtime.h"
+#endif
+
+// Needed to avoid header issues with C++-supporting MPI implementations
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+
+#define TAG_TENSOR 12
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+// Convert from templated types to values we can pass to MPI.
+template <typename T>
+MPI_Datatype MPIType();
+
+// Convert from templated types to TensorFlow data types.
+template <typename T>
+DataType TensorFlowDataType();
+
+#define MPI_REQUIRES_OK(MPI_STATUS) \
+ if ((MPI_STATUS) != MPI_SUCCESS) { \
+ return errors::Unknown("MPI operation failed unexpectedly."); \
+ }
+
+// Copy data from one tensor to another tensor.
+// This uses a custom CUDA stream on GPU, which is necessary to overlap the
+// backpropagation computations with the allreduce.
+template <typename Device>
+void CopyTensorData(void* destination, void* source, size_t size);
+
+// Add a tensor into another tensor, accumulating in place.
+// This uses a custom CUDA stream on GPU, which is necessary to overlap the
+// backpropagation computations with the allreduce.
+template <typename Device, typename T>
+void AccumulateTensorData(T* destination, T* source, size_t size);
+
+// We need to get the right stream for doing CUDA memory transfers and
+// operations, which is possibly different from the standard TensorFlow stream.
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI();
+#endif
+
+/* Perform a ring allreduce on the data. Allocate the necessary output tensor
+ * and store it in the output parameter.
+ *
+ * Assumes that all MPI processes are doing an allreduce of the same tensor,
+ * with the same dimensions.
+ *
+ * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
+ * allreduce, the nodes involved are arranged in a ring:
+ *
+ *   .--0--.
+ *  /       \
+ * 3         1
+ *  \       /
+ *   *--2--*
+ *
+ * Each node always sends to the next clockwise node in the ring, and receives
+ * from the previous one.
+ *
+ * The allreduce is done in two parts: a scatter-reduce and an allgather. In
+ * the scatter reduce, a reduction is done, so that each node ends up with a
+ * chunk of the final output tensor which has contributions from all other
+ * nodes. In the allgather, those chunks are distributed among all the nodes,
+ * so that all nodes have the entire output tensor.
+ *
+ * Both of these operations are done by dividing the input tensor into N
+ * evenly sized chunks (where N is the number of nodes in the ring).
+ *
+ * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
+ * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to
+ * its existing data for that chunk. For example, in the first iteration with
+ * the ring depicted above, you will have the following transfers:
+ *
+ * Segment 0: Node 0 --> Node 1
+ * Segment 1: Node 1 --> Node 2
+ * Segment 2: Node 2 --> Node 3
+ * Segment 3: Node 3 --> Node 0
+ *
+ * In the second iteration, you'll have the following transfers:
+ *
+ * Segment 0: Node 1 --> Node 2
+ * Segment 1: Node 2 --> Node 3
+ * Segment 2: Node 3 --> Node 0
+ * Segment 3: Node 0 --> Node 1
+ *
+ * After this iteration, Node 2 has 3 of the four contributions to Segment 0.
+ * The last iteration has the following transfers:
+ *
+ * Segment 0: Node 2 --> Node 3
+ * Segment 1: Node 3 --> Node 0
+ * Segment 2: Node 0 --> Node 1
+ * Segment 3: Node 1 --> Node 2
+ *
+ * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
+ * has the fully accumulated Segment 1; and so on. The scatter-reduce is
+ * complete.
+ *
+ * Next, the allgather distributes these fully accumulated chunks across all
+ * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
+ * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
+ * For example, at the first iteration, the following transfers will occur:
+ *
+ * Segment 0: Node 3 --> Node 0
+ * Segment 1: Node 0 --> Node 1
+ * Segment 2: Node 1 --> Node 2
+ * Segment 3: Node 2 --> Node 3
+ *
+ * After the first iteration, Node 0 will have a fully accumulated Segment 0
+ * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
+ * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
+ * After this has continued for N - 1 iterations, all nodes will have the
+ * fully accumulated tensor.
+ *
+ * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
+ * allgather. Each send will contain K / N bytes, if there are K bytes in the
+ * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N
+ * bytes of data, and the performance of the allreduce (assuming no latency in
+ * connections) is constrained by the slowest interconnect between the nodes.
+ *
+ */
+template <typename Device, typename T>
+Status RingAllreduce(OpKernelContext* context, const Tensor* input,
+ Tensor* temp, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ T* buffer = (T*)output->tensor_data().data();
+
+ CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
+ output->tensor_data().size());
+
+ // Calculate segment sizes and segment ends
+ const size_t elements_to_reduce = input->NumElements();
+ const size_t segment_size = elements_to_reduce / n;
+ std::vector<size_t> segment_sizes(n, segment_size);
+
+ const size_t residual = elements_to_reduce % n;
+ for (size_t i = 0; i < residual; ++i) {
+ segment_sizes[i]++;
+ }
+
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (size_t i = 1; i < segment_starts.size(); ++i) {
+ segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
+ }
+
+ assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);
+
+ T* segment_recv = (T*)temp->tensor_data().data();
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ MPI_Status recv_status;
+ MPI_Request recv_req;
+
+ // Now start ring. At every step, for every rank, we iterate through
+ // segments with wraparound and send and recv from our neighbors and reduce
+ // locally. At the i'th iteration, rank r sends segment (r-i) and receives
+ // segment (r-i-1).
+ for (int i = 0; i < n - 1; i++) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
+ MPIType<T>(), recv_from, TAG_TENSOR,
+ MPI_COMM_WORLD, &recv_req));
+
+ MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
+ MPIType<T>(), send_to, TAG_TENSOR,
+ MPI_COMM_WORLD));
+
+ T* segment_update = &(buffer[segment_starts[recv_seg_id]]);
+
+ // Wait for recv to complete before reduction
+ MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));
+
+ const size_t recv_seg_size = segment_sizes[recv_seg_id];
+ AccumulateTensorData<Device, T>(segment_update, segment_recv,
+ recv_seg_size);
+ }
+
+ // Now start pipelined ring allgather. At every step, for every rank, we
+ // iterate through segments with wraparound and send and recv from our
+ // neighbors. At the i'th iteration, rank r sends segment (r-i+1) and
+ // receives segment (r-i).
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i + 1) + n) % n;
+ const size_t recv_seg_id = ((r - i) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i+1)
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ // Segment to recv - at every iteration we receive segment (r-i)
+ T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
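The segment index arithmetic above is easy to get wrong, so here is a minimal pure-Python simulation of the documented schedule (illustrative only; the template exchanges real MPI messages and accumulates in place):

def ring_allreduce(values):
    """values[r][s] is rank r's contribution to segment s; returns per-rank results."""
    n = len(values)
    data = [list(row) for row in values]
    # Scatter-reduce: after n-1 steps, rank r holds the full sum of segment (r+1) % n.
    for step in range(n - 1):
        snapshot = [list(row) for row in data]
        for r in range(n):
            seg = (r - step - 1) % n  # segment received from the left neighbor
            data[r][seg] += snapshot[(r - 1) % n][seg]
    # Allgather: circulate the fully reduced segments around the ring.
    for step in range(n - 1):
        snapshot = [list(row) for row in data]
        for r in range(n):
            seg = (r - step) % n  # segment received from the left neighbor
            data[r][seg] = snapshot[(r - 1) % n][seg]
    return data


# Four ranks, four segments: every rank ends up with the per-segment sums.
contributions = [[1, 2, 3, 4],
                 [10, 20, 30, 40],
                 [100, 200, 300, 400],
                 [1000, 2000, 3000, 4000]]
assert ring_allreduce(contributions) == [[1111, 2222, 3333, 4444]] * 4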
+
+// Perform a ring allgather on a Tensor. Other ranks may allgather with a
+// tensor which differs in the first dimension only; all other dimensions must
+// be the same.
+//
+// For more information on the ring allgather, read the documentation for the
+// ring allreduce, which includes a ring allgather.
+template <typename Device, typename T>
+Status RingAllgather(OpKernelContext* context, const Tensor* input,
+ const std::vector<size_t>& sizes, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ assert(sizes.size() == n);
+ assert(input->dim_size(0) == sizes[r]);
+
+ // Compute the number of elements in every "row". We can't compute the number
+ // of elements in each chunk, because the chunks have variable length.
+ size_t elements_per_row = 1;
+ for (int i = 1; i < input->shape().dims(); i++) {
+ elements_per_row *= input->dim_size(i);
+ }
+
+ // Copy data from input tensor to correct place in output tensor.
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (int i = 1; i < n; i++) {
+ segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
+ }
+ size_t offset = segment_starts[r];
+
+ // Copy data to the right offset for this rank.
+ T* buffer = (T*)output->tensor_data().data();
+ CopyTensorData<Device>((void*)(buffer + offset),
+ (void*)input->tensor_data().data(),
+ elements_per_row * sizes[r] * sizeof(T));
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ // Perform a ring allgather. At every step, for every rank, we iterate
+ // through segments with wraparound and send and recv from our neighbors.
+ // At the i'th iteration, rank r sends segment (r-i) and receives segment
+ // (r-1-i).
+ MPI_Status recv_status;
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i)
+ size_t offset_send = segment_starts[send_seg_id];
+ size_t rows_send = sizes[send_seg_id];
+ T* segment_send = &(buffer[offset_send]);
+
+ // Segment to recv - at every iteration we receive segment (r-1-i)
+ size_t offset_recv = segment_starts[recv_seg_id];
+ size_t rows_recv = sizes[recv_seg_id];
+ T* segment_recv = &(buffer[offset_recv]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
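A quick sketch of the offset arithmetic above under assumed sizes (not part of the header): with per-rank first dimensions [17, 32, 81] and rows of 23 * 57 elements, the three segments start at element offsets 0, 17 * 1311 and 49 * 1311.

def segment_starts(sizes, elements_per_row):
    """Element offset at which each rank's rows land in the gathered buffer."""
    starts = [0]
    for rows in sizes[:-1]:
        starts.append(starts[-1] + rows * elements_per_row)
    return starts


assert segment_starts([17, 32, 81], 23 * 57) == [0, 17 * 1311, 49 * 1311]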
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
+
+#undef TENSORFLOW_CONTRIB_MPI_H_
+#endif // TENSORFLOW_CONTRIB_MPI_H_