Diffstat (limited to 'tensorflow/contrib/mpi_collectives/ring.cu.cc')
-rw-r--r-- tensorflow/contrib/mpi_collectives/ring.cu.cc | 117
1 file changed, 117 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc
new file mode 100644
index 0000000000..2f3eef366a
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <>
+MPI_Datatype MPIType<float>() {
+  return MPI_FLOAT;
+}
+template <>
+MPI_Datatype MPIType<int>() {
+  return MPI_INT;
+}
+template <>
+MPI_Datatype MPIType<long long>() {
+  return MPI_LONG_LONG;
+}
+
+template <>
+DataType TensorFlowDataType<float>() {
+  return DT_FLOAT;
+}
+template <>
+DataType TensorFlowDataType<int>() {
+  return DT_INT32;
+}
+template <>
+DataType TensorFlowDataType<long long>() {
+  return DT_INT64;
+}
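+// These trait functions let the templated ring implementation in ring.h pick
+// the matching MPI datatype (and TensorFlow dtype) for an element type T at
+// compile time. Illustrative use only; `sendbuf`, `recvbuf`, `count`, and
+// `comm` below are hypothetical names, not code from this file:
+//
+//   MPI_Allreduce(sendbuf, recvbuf, count, MPIType<float>(), MPI_SUM, comm);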
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
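+// The template definitions for RingAllreduce and RingAllgather live in
+// ring.h; the explicit instantiations above make nvcc emit the GPUDevice
+// code paths for each supported element type in this translation unit.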
+
+// Synchronously copy data on the GPU, using a stream distinct from both the
+// default stream and the streams TensorFlow uses, so that the copy does not
+// serialize against operations unrelated to the allreduce.
+template <>
+void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+  auto stream = CudaStreamForMPI();
+  // Issue the device-to-device copy on the dedicated MPI stream, then block
+  // until it completes so the caller may safely reuse both buffers.
+  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+  cudaStreamSynchronize(stream);
+}
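+// CudaStreamForMPI() is not defined in this file. A minimal sketch of such a
+// helper, assuming it hands out one lazily created dedicated stream
+// (illustrative only, not necessarily the actual implementation):
+//
+//   cudaStream_t CudaStreamForMPI() {
+//     static cudaStream_t stream = [] {
+//       cudaStream_t s;
+//       cudaStreamCreate(&s);
+//       return s;
+//     }();
+//     return stream;
+//   }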
+
+// Elementwise accumulation kernel for GPU.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] += in[i];
+ }
+}
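+// With the fixed <<<32, 256>>> launch used below, 32 * 256 = 8192 threads run
+// this grid-stride loop; for N > 8192 each thread strides through every
+// 8192th element, so any tensor length is covered without retuning the
+// launch configuration.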
+
+// Synchronously accumulate tensors on the GPU, using a stream distinct from
+// both the default stream and the streams TensorFlow uses, so that the
+// reduction does not serialize against operations unrelated to the allreduce.
+#define GENERATE_ACCUMULATE(type)                                          \
+  template <>                                                              \
+  void AccumulateTensorData<GPUDevice, type>(type * dst, type * src,       \
+                                             size_t size) {                \
+    auto stream = CudaStreamForMPI();                                      \
+    elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size);          \
+    cudaStreamSynchronize(stream);                                         \
+  }
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_USE_MPI