aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
blob: ed22ee667f1d73b3f86f77e09bad9bfec7e46391 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_MPI

#include <queue>
#include <thread>
#include <unordered_map>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/mutex.h"

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#include <cuda_runtime.h>
#include "tensorflow/stream_executor/stream.h"
#endif

#include "tensorflow/stream_executor/lib/statusor.h"

#define OMPI_SKIP_MPICXX
#include "third_party/mpi/mpi.h"
#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"

/*
 * MPI Allreduce and Allgather Ops for TensorFlow.
 *
 * TensorFlow natively provides inter-device communication through send and
 * receive ops and inter-node communication through Distributed TensorFlow,
 * based on the same send and receive abstractions. These end up being
 * insufficient for synchronous data-parallel training on HPC clusters where
 * Infiniband or other high-speed interconnects are available.  This module
 * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
 * gathers and reductions and can take advantage of hardware-optimized
 * communication libraries through the MPI implementation.
 *
 * The primary logic of the allreduce and allgather are in RingAllgather() and
 * RingAllreduce(). The background thread which facilitates MPI operations is
 * run in BackgroundThreadLoop(). The provided MPI ops are:
 *      – MPIInit:
 *          Initialize MPI on a given device (CPU or GPU).
 *          Should only be run on a single device in every process.
 *      – MPISize:
 *          Get the number of MPI processes in the global communicator.
 *      – MPIRank:
 *          Get the rank of the current MPI process in the global communicator.
 *      – MPILocalRank:
 *          Get the local rank of the current MPI process within its node.
 *      – MPIAllreduce:
 *          Perform an allreduce on a Tensor, returning the sum
 *          across all MPI processes in the global communicator.
 *      – MPIAllgather:
 *          Perform an allgather on a Tensor, returning the concatenation of
 *          the tensor on the first dimension across all MPI processes in the
 *          global communicator.
 *
 */

template <class T>
using StatusOr = se::port::StatusOr<T>;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace tensorflow {
namespace contrib {
namespace mpi_collectives {

// Make sure template specializations are generated in the ring.cu.cc and the
// ring.cc file, not in this file.
extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
                                                     const Tensor*, Tensor*,
                                                     Tensor*);
extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
                                                           const Tensor*,
                                                           Tensor*, Tensor*);
extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
                                                       const Tensor*, Tensor*,
                                                       Tensor*);
extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
                                                     const Tensor*,
                                                     const std::vector<size_t>&,
                                                     Tensor*);
extern template Status RingAllgather<GPUDevice, long long>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllgather<GPUDevice, float>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
                                                     const Tensor*, Tensor*,
                                                     Tensor*);
extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
                                                           const Tensor*,
                                                           Tensor*, Tensor*);
extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
                                                       const Tensor*, Tensor*,
                                                       Tensor*);
extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
                                                     const Tensor*,
                                                     const std::vector<size_t>&,
                                                     Tensor*);
extern template Status RingAllgather<CPUDevice, long long>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
extern template Status RingAllgather<CPUDevice, float>(
    OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);

namespace {

// Return true if the templated type is GPUDevice, otherwise false.
template <typename T>
bool IsGPUDevice();
template <>
bool IsGPUDevice<GPUDevice>() {
  return true;
};
template <>
bool IsGPUDevice<CPUDevice>() {
  return false;
};

// A callback to call after the MPI communication completes. Since the
// allreduce and allgather ops are asynchronous, this callback is what resumes
// computation after the reduction is completed.
typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;

struct CollectiveOpRecord {
  // The rank performing this piece of the op
  int rank;

  // The name of the op/tensor to be reduced
  std::string name;

  // The op's kernel context
  OpKernelContext* context;

  // Data type of the op
  DataType dtype;

  // The input tensor
  const Tensor* in_t;

  // Allgather: Vector of per-rank first-dimension sizes
  std::vector<size_t> sizes_vec;

  // The temp tensor for intermediate results
  Tensor temp_t;

  // The output tensor
  Tensor* out_t;

  // Whether to run this op on the gpu
  bool on_gpu;

  // The callback to call after the op has completed
  CommunicationDoneCallback callback;
};

// Table storing Tensors to be reduced, keyed by unique name.
// This table contains everything necessary to do the reduction
typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;

// Table for storing Tensor metadata on rank zero. This is used for error
// checking and size calculations, as well as determining when a reduction is
// ready to be done (when all nodes are ready to do it).
typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;

// The global state required for the MPI ops.
//
// MPI is a library that stores a lot of global per-program state and often
// requires running on a single thread. As a result, we have to have a single
// background thread responsible for all MPI operations, and communicate with
// that background thread through global state.
struct MPIGlobalState {
  // An atomic boolean which is set to true when MPI is initialized.
  // This ensures that MPI_Init is never called twice.
  std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;

  // Condition variable to wait for initialization
  condition_variable cv;

  // Whether MPI_Init has been completed on the background thread.
  bool initialization_done = false;

  // Whether MPI_Init succeeded on the background thread.
  Status init_status;

  // A mutex that needs to be used whenever MPI operations touch
  // shared structures.
  mutex mu;

  // Tensors waiting to be allreduced or allgathered.
  TensorTable tensor_table;

  // Queue of MPI requests waiting to be sent to the coordinator node.
  std::queue<MPIRequest> message_queue;

  // Background thread running MPI communication.
  std::thread background_thread;

  // Whether the background thread should shutdown.
  bool shut_down = false;

  // Only exists on the coordinator node (rank zero). Maintains a count of
  // how many nodes are ready to allreduce every tensor (keyed by tensor
  // name).
  std::unique_ptr<MessageTable> message_table;

  // The MPI rank, local rank, and size.
  int rank = 0;
  int local_rank = 0;
  int size = 1;

  // The device that MPI was initialized on. (-1 for no GPU)
  int device = -1;

  // The CUDA stream used for data transfers and within-allreduce operations.
  // A naive implementation would use the TensorFlow StreamExecutor CUDA
  // stream. However, the allreduce and allgather require doing memory copies
  // and kernel executions (for accumulation of values on the GPU). However,
  // the subsequent operations must wait for those operations to complete,
  // otherwise MPI (which uses its own stream internally) will begin the data
  // transfers before the CUDA calls are complete. In order to wait for those
  // CUDA operations, if we were using the TensorFlow stream, we would have
  // to synchronize that stream; however, other TensorFlow threads may be
  // submitting more work to that stream, so synchronizing on it can cause
  // the allreduce to be delayed, waiting for compute totally unrelated to it
  // in other parts of the graph. Overlaying memory transfers and compute
  // during backpropagation is crucial for good performance, so we cannot use
  // the TensorFlow stream, and must use our own stream.
#if GOOGLE_CUDA
  cudaStream_t stream;
  std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
#endif

  ~MPIGlobalState() {
    // Make sure that the destructor of the background thread is safe to
    // call. If a thread is still joinable (not detached or complete) its
    // destructor cannot be called.
    if (background_thread.joinable()) {
      shut_down = true;
      background_thread.join();
    }
  }
};

// All the MPI state that must be stored globally per-process.
static MPIGlobalState mpi_global;

// For clarify in argument lists.
#define RANK_ZERO 0

// A tag used for all coordinator messaging.
#define TAG_NOTIFY 1

// Store the MPIRequest for a name, and return whether the total count of
// MPIRequests for that tensor is now equal to the MPI size (and thus we are
// ready to reduce the tensor).
bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
                          MPIRequest msg, int mpi_size) {
  auto name = msg.tensor_name();
  auto table_iter = message_table->find(name);
  if (table_iter == message_table->end()) {
    message_table->emplace(name, std::vector<MPIRequest>({msg}));
    table_iter = message_table->find(name);
  } else {
    table_iter->second.push_back(msg);
  }

  int count = table_iter->second.size();
  return count == mpi_size;
}

// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
// instructing all ranks to start the reduction to all ranks. The MPIResponse
// also contains error messages in case the submitted MPIRequests were not
// valid (for example, contained mismatched shapes or types).
//
// Constructing the MPIResponse, thus, requires a whole lot of error checking.
MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
                                 std::string name) {
  bool error = false;
  auto it = message_table->find(name);
  assert(it != message_table->end());

  std::vector<MPIRequest> requests = it->second;
  assert(requests.size() > 0);

  std::ostringstream error_message_stream;

  // Check that all data types being reduced or gathered are identical
  auto data_type = requests[0].tensor_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    auto request_type = requests[i].tensor_type();
    if (data_type != request_type) {
      error = true;
      error_message_stream << "Mismatched data types: One rank had type "
                           << DataType_Name(data_type)
                           << ", but another rank had type "
                           << DataType_Name(request_type) << ".";
      break;
    }
  }

  // Check that all requested operations are the same
  auto message_type = requests[0].request_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    if (error) {
      break;
    }

    auto request_type = requests[i].request_type();
    if (message_type != request_type) {
      error = true;
      error_message_stream << "Mismatched MPI operations: One rank did an "
                           << message_type << ", but another rank did an "
                           << request_type << ".";
      break;
    }
  }

  // If we are doing an allreduce, check that all tensor shapes
  // are identical
  if (message_type == MPIRequest::ALLREDUCE) {
    TensorShape tensor_shape = requests[0].tensor_shape();
    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape != request_shape) {
        error = true;
        error_message_stream << "Mismatched allreduce tensor shapes: "
                             << "One rank reduced a tensor of shape "
                             << tensor_shape.DebugString()
                             << ", but another rank sent a tensor of shape "
                             << request_shape.DebugString() << ".";
        break;
      }
    }
  }

  // If we are doing an allgather, make sure all but the first dimension are
  // the same. The first dimension may be different and the output tensor is
  // the sum of the first dimension. Collect the sizes by rank.
  if (message_type == MPIRequest::ALLGATHER) {
    TensorShape tensor_shape = requests[0].tensor_shape();

    if (tensor_shape.dims() == 0) {
      error = true;
      error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
    }

    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape.dims() != request_shape.dims()) {
        error = true;
        error_message_stream << "Mismatched allgather tensor shapes: "
                             << "One rank gathered a tensor of rank "
                             << tensor_shape.dims()
                             << ", but another rank sent a tensor of rank "
                             << request_shape.dims() << ".";
        break;
      }

      for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
        if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
          error = true;
          error_message_stream
              << "Mismatched allgather tensor shapes: "
              << "One rank gathered a tensor with dimension " << dim
              << " equal to " << tensor_shape.dim_size(dim)
              << ", but another rank sent a tensor with dimension " << dim
              << " equal to " << request_shape.dim_size(dim) << ".";
          break;
        }
      }
    }
  }

  MPIResponse response;
  response.set_tensor_name(name);
  if (error) {
    std::string error_message = error_message_stream.str();
    response.set_response_type(MPIResponse::ERROR);
    response.set_error_message(error_message);
  } else {
    auto response_type = MPIResponse::ERROR;
    if (message_type == MPIRequest::ALLREDUCE) {
      response_type = MPIResponse::ALLREDUCE;
    } else {
      response_type = MPIResponse::ALLGATHER;
    }
    response.set_response_type(response_type);
  }

  // Clear all queued up requests for this name. They are now taken care of
  // by the constructed MPI response.
  message_table->erase(it);

  return response;
}

// Process an MPIResponse by doing a reduction, a gather, or raising an error.
void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
  OpKernelContext* context;
  const Tensor* input_tensor;
  std::vector<size_t> sizes_vec;
  Tensor temp_tensor;
  Tensor* output_tensor;
  CommunicationDoneCallback callback;
  bool on_gpu;
  {
    // Lock on the tensor table.
    mutex_lock guard(mpi_global.mu);

    // We should never fail at finding this key in the tensor table.
    auto name = response.tensor_name();
    auto iter = tensor_table.find(name);
    assert(iter != tensor_table.end());

    assert(response.response_type() == MPIResponse::ALLREDUCE ||
           response.response_type() == MPIResponse::ALLGATHER ||
           response.response_type() == MPIResponse::ERROR);

    CollectiveOpRecord record = iter->second;
    context = record.context;
    input_tensor = record.in_t;
    sizes_vec = record.sizes_vec;
    temp_tensor = record.temp_t;
    output_tensor = record.out_t;
    on_gpu = record.on_gpu;
    callback = record.callback;

    // Clear the tensor table of this tensor and its callbacks; the rest of
    // this function takes care of it.
    tensor_table.erase(iter);
  }

  // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
  // link to non-existent symbols.
#if GOOGLE_CUDA
#define GPU_DEVICE_IF_CUDA GPUDevice
#else
#define GPU_DEVICE_IF_CUDA CPUDevice
#endif

  Status status;
  auto dtype = input_tensor->dtype();
  if (response.response_type() == MPIResponse::ALLGATHER) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, float>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, int>(context, input_tensor,
                                                      sizes_vec, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, long long>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allgather.");
    }
  } else if (response.response_type() == MPIResponse::ALLREDUCE) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, float>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, int>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, long long>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allreduce.");
    }
  } else if (response.response_type() == MPIResponse::ERROR) {
    status = errors::FailedPrecondition(response.error_message());
  }

  if (status.ok()) {
    callback(StatusOr<Tensor>(*output_tensor));
  } else {
    callback(StatusOr<Tensor>(status));
  }
}

// The MPI background thread loop coordinates all the MPI processes and the
// tensor reductions. The design of the communicator mechanism is limited by a
// few considerations:
//
//      1. Some MPI implementations require all MPI calls to happen from a
//      single thread. Since TensorFlow may use several threads for graph
//      processing, this means we must have our own dedicated thread for
//      dealing with MPI.
//      2. We want to gracefully handle errors, when MPI processes do not
//      properly agree upon what should happen (such as mismatched types or
//      shapes). To do so requires the MPI processes to know about the shapes
//      and types of the relevant tensors on the other processes.
//      3. The MPI reductions and gathers should be able to happen in parallel
//      with other ongoing operations. Since MPI uses an internal
//      (inaccessible) GPU stream separate from the TF GPUDevice streams, we
//      cannot explicitly synchronize memcpys or kernels with it. As a result,
//      MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
//      ordering of memcpys and kernels with respect to TF streams.
//      4. NOTE: We cannot guarantee that all the MPI processes reduce their
//      tensors in the same order. Thus, there must be a way to ensure the
//      reduction memcpys and kernels occur for correct tensors across all
//      ranks at the same time. We choose to use a coordinator (rank ID 0) to
//      gather and trigger the reduction operations that are ready to execute.
//
// The coordinator currently follows a master-worker paradigm. Rank zero acts
// as the master (the "coordinator"), whereas all other ranks are simply
// workers. Each rank runs its own background thread which progresses in ticks.
// In each tick, the following actions happen:
//
//      a) The workers send any available MPIRequests to the coordinator. These
//      MPIRequests indicate what the worker would like to do (i.e. which
//      tensor they would like to gather or reduce, as well as their shape and
//      type). They repeat this for every tensor that they would like to
//      operate on after that tensor's collective op has executed ComputeAsync.
//
//      b) The workers send an empty "DONE" message to the coordinator to
//      indicate that there are no more tensors they wish to operate on.
//
//      c) The coordinator receives the MPIRequests from the workers, as well
//      as from its own TensorFlow ops, and stores them in a request table. The
//      coordinator continues to receive MPIRequest messages until it has
//      received MPI_SIZE number of empty "DONE" messages.
//
//      d) The coordinator finds all tensors that are ready to be reduced,
//      gathered, or all operations that result in an error. For each of those,
//      it sends an MPIResponse to all the workers. When no more MPIResponses
//      are available, it sends a "DONE" response to the workers. If the
//      process is being shutdown, it instead sends a "SHUTDOWN" response.
//
//      e) The workers listen for MPIResponse messages, processing each one by
//      doing the required reduce or gather, until they receive a "DONE"
//      response from the coordinator. At that point, the tick ends.
//      If instead of "DONE" they receive "SHUTDOWN", they exit their
//      background loop.
// TODO: Use the global mpi_global state variable instead of a local one
void BackgroundThreadLoop() {
#if GOOGLE_CUDA
  // Set the device, so that this thread uses the same GPU context as the
  // calling thread.
  // TODO: Ensure that this is operating correctly. The background thread
  // needs to be able to control all GPUs that the rank has access to, and
  // might be more than 1 GPU. Tensors could be resident in any of the
  // GPUs, so the background thread's accumulate and copy kernels might need
  // to correctly set the device and it might be necessary for the background
  // thread to manage multiple streams.
  cudaSetDevice(mpi_global.device);
  cudaStreamCreate(&mpi_global.stream);
#endif

  // Initialize MPI. This must happen on the background thread, since not all
  // MPI implementations support being called from multiple threads.
  auto init_result = MPI_Init(NULL, NULL);
  if (init_result != MPI_SUCCESS) {
    mpi_global.init_status =
        errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
    mpi_global.initialization_done = true;
    mpi_global.cv.notify_all();
    return;
  } else {
    mpi_global.init_status = Status::OK();
  }

  // Get MPI rank to determine if we are rank zero.
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  bool is_coordinator = rank == 0;

  // Get MPI size to determine how many tensors to wait for before reducing.
  int size;
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Determine local rank by querying the local communicator.
  MPI_Comm local_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
                      &local_comm);
  int local_rank;
  MPI_Comm_rank(local_comm, &local_rank);

  mpi_global.rank = rank;
  mpi_global.local_rank = local_rank;
  mpi_global.size = size;
  mpi_global.initialization_done = true;

  // Notify calling thread that initialization is complete
  mpi_global.cv.notify_all();

  // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
  // Initialize the tensor count table. No tensors are available yet.
  if (is_coordinator) {
    mpi_global.message_table =
        std::unique_ptr<MessageTable>(new MessageTable());
  }

  // The coordinator sends a SHUTDOWN message to trigger shutdown.
  bool should_shut_down = false;
  do {
    // TODO: Eliminate the need for thread sleep by making all activity
    // depend on other activity (e.g. condition or MPI waits).
    std::this_thread::sleep_for(std::chrono::milliseconds(1));

    // Copy the data structures from global state under this lock.
    // However, don't keep the lock for the rest of the loop, so that
    // enqueued stream callbacks can continue.
    std::queue<MPIRequest> message_queue;
    {
      mutex_lock guard(mpi_global.mu);
      while (!mpi_global.message_queue.empty()) {
        MPIRequest message = mpi_global.message_queue.front();
        mpi_global.message_queue.pop();
        message_queue.push(message);
      }
    }

    // Collect all tensors that are ready to be reduced. Record them in the
    // tensor count table (rank zero) or send them to rank zero to be
    // recorded (everyone else).
    std::vector<std::string> ready_to_reduce;
    while (!message_queue.empty()) {
      // Pop the first available message message
      MPIRequest message = message_queue.front();
      message_queue.pop();

      if (is_coordinator) {
        bool reduce =
            IncrementTensorCount(mpi_global.message_table, message, size);
        if (reduce) {
          ready_to_reduce.push_back(message.tensor_name());
        }
      } else {
        std::string encoded_message;
        message.SerializeToString(&encoded_message);
        MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
                 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    }

    // Rank zero has put all its own tensors in the tensor count table.
    // Now, it should count all the tensors that are coming from other
    // ranks at this tick. It should keep getting tensors until it gets a
    // DONE message from all the other ranks.
    if (is_coordinator) {
      // Count of DONE messages. Keep receiving messages until the number
      // of messages is equal to the number of processes. Initialize to
      // one since the coordinator is effectively done.
      int completed_ranks = 1;
      while (completed_ranks != size) {
        MPI_Status status;
        MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int source_rank = status.MPI_SOURCE;
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // If the length is zero, this is a DONE message.
        if (msg_length == 0) {
          completed_ranks++;
          MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
                   &status);
          continue;
        }

        // Get tensor name from MPI into an std::string.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
                 MPI_COMM_WORLD, &status);
        std::string received_data(buffer);
        delete[] buffer;

        MPIRequest received_message;
        received_message.ParseFromString(received_data);
        auto received_name = received_message.tensor_name();

        bool reduce = IncrementTensorCount(mpi_global.message_table,
                                           received_message, size);
        if (reduce) {
          ready_to_reduce.push_back(received_name);
        }
      }

      // At this point, rank zero should have a fully updated tensor
      // count table and should know all the tensors that need to be
      // reduced or gathered, and everyone else should have sent all
      // their information to rank zero. We can now do reductions and
      // gathers; rank zero will choose which ones and in what order,
      // and will notify the other ranks before doing each reduction.
      for (int i = 0; i < ready_to_reduce.size(); i++) {
        // Notify all nodes which tensor we'd like to reduce now
        auto name = ready_to_reduce[i];
        MPIResponse response =
            ConstructMPIResponse(mpi_global.message_table, name);

        std::string encoded_response;
        response.SerializeToString(&encoded_response);
        for (int r = 1; r < size; r++) {
          MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                   MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
        }

        // Perform the reduction. All nodes should end up performing
        // the same reduction.
        PerformCollectiveOp(mpi_global.tensor_table, response);
      }

      // Notify all nodes that we are done with the reductions for this
      // tick.
      MPIResponse done_response;
      should_shut_down = mpi_global.shut_down;
      done_response.set_response_type(
          mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
      std::string encoded_response;
      done_response.SerializeToString(&encoded_response);
      for (int r = 1; r < size; r++) {
        MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    } else {
      // Notify the coordinator that this node is done sending messages.
      // A DONE message is encoded as a zero-length message.
      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);

      // Receive names for tensors to reduce from rank zero. Once we
      // receive a empty DONE message, stop waiting for more names.
      while (true) {
        MPI_Status status;
        MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // Get tensor name from MPI into an std::string.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
                 &status);
        std::string received_message(buffer);
        delete[] buffer;

        MPIResponse response;
        response.ParseFromString(received_message);
        if (response.response_type() == MPIResponse::DONE) {
          // No more messages this tick
          break;
        } else if (response.response_type() == MPIResponse::SHUTDOWN) {
          // No more messages this tick, and the background thread
          // should shut down
          should_shut_down = true;
          break;
        } else {
          // Process the current message
          PerformCollectiveOp(mpi_global.tensor_table, response);
        }
      }
    }
  } while (!should_shut_down);

  MPI_Finalize();
}

// Initialize MPI and start the MPI background thread. Ensure that this is
// only done once no matter how many times this function is called.
Status InitializeMPIOnce(bool gpu) {
  // Ensure MPI is only initialized once.
  if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;

  mpi_global.device = -1;
#if GOOGLE_CUDA
  if (gpu) {
    cudaGetDevice(&mpi_global.device);
  }
#endif

  // Start the MPI background thread, which assumes MPI is initialized
  // TODO: Change this to a Tensorflow thread
  mpi_global.background_thread = std::thread(BackgroundThreadLoop);

  // Wait to ensure that the background thread has finished initializing MPI
  mutex_lock guard(mpi_global.mu);
  mpi_global.cv.wait(guard);
  if (!mpi_global.initialization_done) {
    mpi_global.init_status =
        errors::Unknown("Failed to wait for MPI initialization.");
  }

  return mpi_global.init_status;
}

// Check that MPI is initialized.
Status IsMPIInitialized() {
  if (!mpi_global.initialization_done) {
    return errors::FailedPrecondition(
        "MPI has not been initialized; use tf.contrib.mpi.Session.");
  }
  return Status::OK();
}

// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
// only adds the op's record into the local op queue (to track the op's
// progress), and sends a message to the coordinator indicating that this rank
// is ready to begin. The MPI background thread will handle the MPI message.
void EnqueueTensorCollective(CollectiveOpRecord record,
                             MPIRequest::RequestType rtype) {
  const Tensor* input_tensor = record.in_t;
  MPIRequest message;
  message.set_request_rank(record.rank);
  message.set_tensor_name(record.name);
  message.set_tensor_type(record.dtype);
  message.set_request_type(rtype);
  input_tensor->shape().AsProto(message.mutable_tensor_shape());

  mutex_lock guard(mpi_global.mu);
  mpi_global.tensor_table.emplace(record.name, record);
  mpi_global.message_queue.push(message);
}

}  // namespace

#if GOOGLE_CUDA
cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
#endif

// Op to initialize MPI in the current process. The settings used in the
// configuration are the same that must be used for all future MPI ops.
template <typename Device>
class MPIInitOp : public OpKernel {
 public:
  explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    bool on_gpu = IsGPUDevice<Device>();
    OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
                        MPIInitOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
                        MPIInitOp<GPUDevice>);
#endif

// Op to get the current MPI Size.
template <typename Device>
class MPISizeOp : public OpKernel {
 public:
  explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.size;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
                        MPISizeOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
                        MPISizeOp<GPUDevice>);
#endif

// Op to get the current MPI Rank.
template <typename Device>
class MPIRankOp : public OpKernel {
 public:
  explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.rank;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
                        MPIRankOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
                        MPIRankOp<GPUDevice>);
#endif

// Op to get the current local MPI Rank.
template <typename Device>
class MPILocalRankOp : public OpKernel {
 public:
  explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES_OK(context, IsMPIInitialized());

    // Write integer to output tensor
    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));

    auto flat = output->flat<int>();
    flat(0) = mpi_global.local_rank;
  }
};

REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
                        MPILocalRankOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
    MPILocalRankOp<GPUDevice>);
#endif

template <typename Device>
class MPIAllreduceOp : public AsyncOpKernel {
 public:
  explicit MPIAllreduceOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  // Although this op is handled asynchronously, the ComputeAsync call is
  // very inexpensive. It only sets up a CollectiveOpRecord and places it
  // in the table for the background thread to handle. Thus, we do not need
  // a TF pool thread to perform the op.
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    const Tensor* input_tensor = &context->input(0);
    Tensor* output_tensor;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, input_tensor->shape(), &output_tensor),
        done);

    // Record allocated on stack so op can fail without memory leak
    CollectiveOpRecord record;
    record.name = name();
    record.context = context;
    record.in_t = input_tensor;
    record.out_t = output_tensor;
    record.on_gpu = IsGPUDevice<Device>();
    record.dtype = input_tensor->dtype();

    const size_t temp_size =
        (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
    TensorShape temp_shape;
    temp_shape.AddDim(temp_size);
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(input_tensor->dtype(),
                                                temp_shape, &record.temp_t),
                         done);

    auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
      context->SetStatus(status.status());
      done();
    };
    record.callback = allreduce_done_callback;

    auto allreduce_launch_callback = [record] {
      EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
    };

    // If we are on a CPU, our device context will be null and we can't
    // get a stream to enqueue this on. On a CPU this op is called when the
    // data is already available, so we can just immediately do the
    // allreduce; we don't have to wait for the data to get populated.
#if GOOGLE_CUDA
    auto device_context = context->op_device_context();
    if (device_context == nullptr) {
      allreduce_launch_callback();
    } else {
      auto stream = device_context->stream();
      stream->ThenDoHostCallback(allreduce_launch_callback);
    }
#else
    allreduce_launch_callback();
#endif
  }
};

REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
                        MPIAllreduceOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
                        MPIAllreduceOp<GPUDevice>);
#endif

template <typename Device>
class MPIAllgatherOp : public AsyncOpKernel {
 public:
  explicit MPIAllgatherOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  // Although this op is handled asynchronously, the ComputeAsync call is
  // very inexpensive. It only sets up a CollectiveOpRecord and places it
  // in the table for the background thread to handle. Thus, we do not need
  // a TF pool thread to perform the op.
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    const Tensor* input_tensor = &context->input(0);
    const Tensor* sizing_tensor = &context->input(1);

    // Record allocated on stack so op can fail without memory leak
    CollectiveOpRecord record;
    record.name = name();
    record.context = context;
    record.in_t = input_tensor;
    record.on_gpu = IsGPUDevice<Device>();

    // Construct the output size from the sizing tensor
    size_t output_first_dim = 0;
    if (sizing_tensor->shape().dims() == 0) {
      // 0-dim sizing_tensor implies that the op is just gathering
      // a single element from each rank
      output_first_dim = mpi_global.size;
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(1);
      }
    } else {
      // Collect the total output tensor sizing from the sizing tensor
      // NOTE: The sizing tensor is forced to be placed on the CPU by
      // declaring the input as HostMemory, so it is valid to read it here.
      const int64* sizing_array =
          (const int64*)sizing_tensor->tensor_data().data();
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(sizing_array[i]);
        output_first_dim += sizing_array[i];
      }
    }

    TensorShape output_shape;
    output_shape.AddDim(output_first_dim);
    for (int i = 1; i < input_tensor->shape().dims(); i++) {
      output_shape.AddDim(input_tensor->shape().dim_size(i));
    }

    Tensor* output_tensor;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, output_shape, &output_tensor),
        done);

    record.out_t = output_tensor;
    record.dtype = input_tensor->dtype();

    auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
      context->SetStatus(status.status());
      done();
    };
    record.callback = allgather_done_callback;

    auto allgather_launch_callback = [record] {
      EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
    };

    // If we are on a CPU, our device context will be null and we can't
    // get a stream to enqueue this on. On a CPU this op is called when the
    // data is already available, so we can just immediately do the
    // allgather; we don't have to wait for the data to get populated.
#if GOOGLE_CUDA
    auto device_context = context->op_device_context();
    if (device_context == nullptr) {
      allgather_launch_callback();
    } else {
      auto stream = device_context->stream();
      stream->ThenDoHostCallback(allgather_launch_callback);
    }
#else
    allgather_launch_callback();
#endif
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
    MPIAllgatherOp<CPUDevice>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
    MPIAllgatherOp<GPUDevice>);
#endif

}  // namespace mpi_collectives
}  // namespace contrib
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_MPI