/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_MPI

#define EIGEN_USE_THREADS

#include "tensorflow/contrib/mpi_collectives/ring.h"

#include <cstddef>
#include <cstring>
#include <vector>

namespace tensorflow {
namespace contrib {
namespace mpi {

// CPU implementations of the device hooks used by the ring collectives, plus
// the CPU explicit instantiations of RingAllreduce / RingAllgather.
using CPUDevice = Eigen::ThreadPoolDevice;

// Explicit instantiation declarations: these suppress implicit instantiation
// in this translation unit, so the instantiations must be provided by another
// translation unit of the program.
extern template MPI_Datatype MPIType<float>();
extern template MPI_Datatype MPIType<int>();
extern template MPI_Datatype MPIType<long long>();
extern template DataType TensorFlowDataType<float>();
extern template DataType TensorFlowDataType<int>();
extern template DataType TensorFlowDataType<long long>();

// NOTE: The CPU specializations below must be declared before the explicit
// instantiations of RingAllreduce / RingAllgather further down. Those
// instantiations use the specializations, and C++ requires an explicit
// specialization to be declared before its first use ([temp.expl.spec]).

// Copy `size` bytes from `src` to `dst` on a CPU using a plain memcpy.
// As with std::memcpy, the two ranges must not overlap.
template <>
void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
  // std::memcpy is undefined for null pointers even when the count is zero.
  // A zero-length copy has nothing to do, so skip the call in that case.
  if (size == 0) {
    return;
  }
  std::memcpy(dst, src, size);
}

namespace {

// In-place element-wise sum: dst[i] += src[i] for every i in [0, size).
// The index is size_t, matching `size`, so the counter cannot wrap around
// before reaching the bound on buffers with more than UINT_MAX elements.
template <typename T>
void AccumulateOnCpu(T* dst, const T* src, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    dst[i] += src[i];
  }
}

}  // namespace

// Accumulate values on a CPU: one explicit specialization per element type
// that the ring collectives are instantiated for below.
template <>
void AccumulateTensorData<CPUDevice, int>(int* dst, int* src, size_t size) {
  AccumulateOnCpu(dst, src, size);
}

template <>
void AccumulateTensorData<CPUDevice, long long>(long long* dst, long long* src,
                                                size_t size) {
  AccumulateOnCpu(dst, src, size);
}

template <>
void AccumulateTensorData<CPUDevice, float>(float* dst, float* src,
                                            size_t size) {
  AccumulateOnCpu(dst, src, size);
}

// Generate all necessary specializations for RingAllreduce.
template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
                                              Tensor*, Tensor*);
template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
                                                    const Tensor*, Tensor*,
                                                    Tensor*);
template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
                                                Tensor*, Tensor*);

// Generate all necessary specializations for RingAllgather.
template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
                                              const std::vector<size_t>&,
                                              Tensor*);
template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
                                                    const Tensor*,
                                                    const std::vector<size_t>&,
                                                    Tensor*);
template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
                                                const std::vector<size_t>&,
                                                Tensor*);

}  // namespace mpi
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI