aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/nccl/kernels/nccl_manager.h
blob: 8d5e5ddf763ccfaa9f6d00eec8ab925ebf28ba87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_

#ifdef GOOGLE_CUDA

#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "external/nccl_archive/src/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// NcclManager assembles participants from multiple devices into NCCL
// collectives (all-reduce and broadcast), makes the asynchronous NCCL calls,
// and manages the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
class NcclManager {
 public:
  // Callback invoked with the Status of a participant's collective operation.
  using DoneCallback = std::function<void(Status)>;

  NcclManager();
  ~NcclManager();

  // Returns the shared NcclManager instance (presumably a process-wide
  // singleton; see nccl_manager.cc).
  static NcclManager* instance();

  // Adds one participant to an all-reduce, sending in data from <in_t> and
  // receiving the result of the all-reduce in <out_t>. The device for this
  // participant is managed by <executor>, and its events are polled by
  // <event_mgr>. Participants are grouped into one collective by <key>;
  // <num_devices> is presumably the number of participants expected for that
  // collective (TODO confirm against nccl_manager.cc).
  //
  // This is an asynchronous call. When <done_callback> is called, <out_t> has
  // been set to the all-reduce result (note: the stream may not yet have been
  // synced).
  //
  // <tensor_stream> is the stream that should be waited on to ensure <in_t>'s
  // data is available on the GPU for the communication stream to access. It
  // is also the stream that will use the produced data; <done_callback> is
  // not called until the next kernel launched on <tensor_stream> would see
  // the data.
  void AddToAllReduce(int num_devices, const string& key,
                      ncclRedOp_t reduction_op,
                      perftools::gputools::StreamExecutor* executor,
                      EventMgr* event_mgr,
                      perftools::gputools::Stream* tensor_stream,
                      const Tensor* in_t, Tensor* out_t,
                      const DoneCallback& done_callback);

  // AddBroadcastSend and AddBroadcastRecv combine to send data from one
  // sender to all receivers. The sender supplies <in_t>; each receiver
  // receives into <out_t>. Participants of one broadcast share the same <key>.
  void AddBroadcastSend(int num_devices, const string& key,
                        perftools::gputools::StreamExecutor* executor,
                        EventMgr* event_mgr,
                        perftools::gputools::Stream* tensor_stream,
                        const Tensor* in_t, DoneCallback done_callback);
  void AddBroadcastRecv(int num_devices, const string& key,
                        perftools::gputools::StreamExecutor* executor,
                        EventMgr* event_mgr,
                        perftools::gputools::Stream* tensor_stream,
                        Tensor* out_t, DoneCallback done_callback);

 private:
  enum CollectiveType {
    kAllReduce = 1,
    kBroadcast = 2,
  };
  struct Collective;
  struct Communicator;
  struct CommunicatorMember;
  struct NcclStream;
  struct Participant;

  // Returns a communicator spanning the devices that participate in
  // <collective> (presumably reusing an entry of communicators_ when one
  // already covers the same devices; see nccl_manager.cc).
  Communicator* GetCommunicator(Collective* collective);

  // Registers <participant> with the collective identified by <key>, taking
  // ownership of it. The collective's data type, kind, and (for all-reduce)
  // reduction op are given by <data_type>, <collective_type>, <reduction_op>.
  void AddParticipant(int num_devices, const string& key,
                      std::unique_ptr<Participant> participant,
                      DataType data_type, CollectiveType collective_type,
                      ncclRedOp_t reduction_op);

  // Runs <collective>. This call takes ownership of <collective>.
  void RunCollective(const string& key, Collective* collective);

  // NOTE(review): presumably the loop run for each communication <stream>
  // that launches queued NCCL kernels; confirm in nccl_manager.cc.
  void LoopKernelLaunches(NcclStream* stream);

  // Guards collectives_ and device_to_comm_streams_.
  mutex mu_;

  // Maps key to collectives currently being assembled or run.
  std::unordered_map<string, std::unique_ptr<Collective>> collectives_
      GUARDED_BY(mu_);

  // Maps a device to the communication streams that make up its collective.
  // This is used to share the stream across different communicators that
  // include the same device.
  std::map<perftools::gputools::StreamExecutor*,
           std::vector<std::unique_ptr<NcclStream>>>
      device_to_comm_streams_ GUARDED_BY(mu_);

  // Communicators created so far, owned by this manager.
  // NOTE(review): not annotated GUARDED_BY(mu_); confirm that every access in
  // nccl_manager.cc holds mu_, and add the annotation if so.
  std::vector<std::unique_ptr<Communicator>> communicators_;

  TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_