aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/mpi_collectives/ring.cc
blob: d93233eb210b80df10fd9c2c7975ce77112d18a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_MPI

#define EIGEN_USE_THREADS

#include "tensorflow/contrib/mpi_collectives/ring.h"

namespace tensorflow {
namespace contrib {
namespace mpi {

// Device tag used by every CPU specialization and instantiation in this file.
using CPUDevice = Eigen::ThreadPoolDevice;

// The MPI / TensorFlow type-mapping helpers are declared `extern template`, so
// this translation unit does not instantiate its own copies; their explicit
// instantiations are expected to live in another translation unit.
// Grouped by element type.
extern template MPI_Datatype MPIType<int>();
extern template DataType TensorFlowDataType<int>();

extern template MPI_Datatype MPIType<long long>();
extern template DataType TensorFlowDataType<long long>();

extern template MPI_Datatype MPIType<float>();
extern template DataType TensorFlowDataType<float>();

// Explicitly instantiate the ring collectives (RingAllreduce and RingAllgather)
// on the CPU device for each supported element type: int, long long, float.
// These are explicit instantiations, not specializations: the generic template
// definitions are compiled here for the listed types.
#define INSTANTIATE_RING_COLLECTIVES(T)                                      \
  template Status RingAllreduce<CPUDevice, T>(OpKernelContext*,            \
                                              const Tensor*, Tensor*,       \
                                              Tensor*);                     \
  template Status RingAllgather<CPUDevice, T>(                             \
      OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
INSTANTIATE_RING_COLLECTIVES(int)
INSTANTIATE_RING_COLLECTIVES(long long)
INSTANTIATE_RING_COLLECTIVES(float)
#undef INSTANTIATE_RING_COLLECTIVES

// Copy data on a CPU using a straightforward memcpy.
//
// `size` is a byte count; it is passed unchanged to std::memcpy. As with
// memcpy, `dst` and `src` must not overlap.
template <>
void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
  // std::memcpy requires valid (non-null) pointers even for a zero-length
  // copy. Empty tensors may supply null buffers, so skip the call entirely
  // when there is nothing to copy.
  if (size == 0) {
    return;
  }
  std::memcpy(dst, src, size);
}

// Accumulate values on a CPU: dst[i] += src[i] for every i in [0, size).
//
// `size` is an element count (the loop indexes `dst`/`src` directly). The loop
// index is size_t so it can cover the full range of `size`. A narrower
// `unsigned int` would wrap around for sizes above UINT_MAX, and the loop
// would never terminate.
//
// One explicit specialization is generated per supported element type. The
// macro body deliberately ends without a semicolon, so no stray empty
// declarations are emitted.
#define GENERATE_ACCUMULATE(type)                                    \
  template <>                                                        \
  void AccumulateTensorData<CPUDevice, type>(type * dst, type * src, \
                                             size_t size) {          \
    for (size_t i = 0; i < size; ++i) {                              \
      dst[i] += src[i];                                              \
    }                                                                \
  }
GENERATE_ACCUMULATE(int)
GENERATE_ACCUMULATE(long long)
GENERATE_ACCUMULATE(float)
#undef GENERATE_ACCUMULATE

}  // namespace mpi
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI