aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/mpi/mpi_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/mpi/mpi_utils.cc')
-rw-r--r--tensorflow/contrib/mpi/mpi_utils.cc72
1 file changed, 72 insertions, 0 deletions
diff --git a/tensorflow/contrib/mpi/mpi_utils.cc b/tensorflow/contrib/mpi/mpi_utils.cc
new file mode 100644
index 0000000000..b8e7d1a274
--- /dev/null
+++ b/tensorflow/contrib/mpi/mpi_utils.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/contrib/mpi/mpi_utils.h"
+namespace tensorflow {
+
+#define max_worker_name_length 128
+
+MPIUtils::MPIUtils(const std::string& worker_name) {
+ InitMPI();
+ // Connect the MPI process IDs to the worker names that are used by TF.
+ // Gather the names of all the active processes (name can't be longer than
+ // 128 bytes)
+ int proc_id = 0, number_of_procs = 1;
+ char my_name[max_worker_name_length];
+ MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id));
+ MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
+
+ CHECK(worker_name.size() < max_worker_name_length)
+ << "Specified worker name is too long.";
+ snprintf(my_name, max_worker_name_length, worker_name.c_str());
+ std::vector<char> worker_names(number_of_procs * max_worker_name_length);
+ MPI_CHECK(MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR,
+ &worker_names[0], max_worker_name_length, MPI_CHAR,
+ MPI_COMM_WORLD));
+
+ if (proc_id == 0) LOG(INFO) << "MPI process-ID to gRPC server name map: \n";
+ for (int i = 0; i < number_of_procs; i++) {
+ name_to_id_[std::string(&worker_names[i * 128])] = i;
+ if (proc_id == 0)
+ LOG(INFO) << "Process: " << i
+ << "\tgRPC-name: " << std::string(&worker_names[i * 128])
+ << std::endl;
+ }
+}
+
+void MPIUtils::InitMPI() {
+ // Initialize the MPI environment if that hasn't been done
+ int flag = 0;
+ MPI_CHECK(MPI_Initialized(&flag));
+ if (!flag) {
+ int proc_id = 0, number_of_procs = 1, len = -1;
+ char my_host_name[max_worker_name_length];
+ // MPI_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &flag));
+ MPI_CHECK(MPI_Init(0, 0));
+ MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id));
+ MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
+ MPI_CHECK(MPI_Get_processor_name(my_host_name, &len));
+ fprintf(stderr,
+ "MPI Environment initialised. Process id: %d Total processes: %d "
+ "|| Hostname: %s \n",
+ proc_id, number_of_procs, my_host_name);
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI