// Source: tensorflow/core/common_runtime/broadcaster.h
// (blob 799228b16170f9c3875b4db298e12cba5a1705f1)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_

#include <vector>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"

namespace tensorflow {

// Tree-algorithm implementation of collective broadcast.
//
// Only the interface is declared here; the member functions are defined out
// of line.  Pointer members marked "Not owned" and the `col_params_`
// reference are borrowed, so whatever they refer to must stay valid for as
// long as this object uses them.
//
// NOTE(review): this class uses std::unique_ptr, but the only standard
// header this file includes directly is <vector>; <memory> presumably
// arrives transitively -- consider including it explicitly.
class Broadcaster {
 public:
  // Constructs a broadcaster.  The roles below are inferred from the
  // parameter names and the matching data members -- confirm against the
  // definition:
  //   col_exec, dev_mgr, ctx, output: presumably kept in the corresponding
  //     "Not owned" members; the caller retains ownership.
  //   col_params: presumably bound to the `col_params_` reference member, in
  //     which case it must outlive this object.
  //   exec_key:   presumably copied into `exec_key_`.
  //   params, step_id: no matching data member is declared here; how they are
  //     used cannot be told from this header.
  Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
              OpKernelContext* ctx, OpKernelContext::Params* params,
              const CollectiveParams& col_params, const string& exec_key,
              int64 step_id, Tensor* output);

  // Entry point for the broadcast.  `done` is taken by value.
  // NOTE(review): presumably `done` is stored in `done_` and invoked with the
  // final status once the broadcast finishes -- confirm in the definition.
  void Run(StatusCallback done);

  // Returns the rank of the device from which this device should receive
  // its value, -1 if no value should be received.
  // Static: the result depends only on `cp` and `subdiv`, not on any
  // Broadcaster instance.
  static int TreeRecvFrom(const CollectiveParams& cp, int subdiv);

  // Populates targets with the ranks of the devices to which this device
  // should forward the value.
  // Static: works from `cp` and `subdiv` alone; `targets` is an out-parameter.
  // NOTE(review): whether `targets` is cleared first or appended to is not
  // visible from this header -- confirm in the definition.
  static void TreeSendTo(const CollectiveParams& cp, int subdiv,
                         std::vector<int>* targets);

 private:
  // Sends `src_tensor` asynchronously from this device to device at `dst_rank`
  // in `subdiv`.  Calls `done` upon completion.
  // NOTE(review): `src_rank` is presumably the rank of this (sending)
  // device -- TODO confirm.
  void DispatchSend(int subdiv, int dst_rank, int src_rank,
                    const Tensor* src_tensor, const StatusCallback& done);
  // Receives a tensor into the memory buffer owned by `dst_tensor` at this
  // device from device at `src_rank` in `subdiv`.  Calls `done` upon
  // completion.
  // NOTE(review): `dst_rank` is presumably the rank of this (receiving)
  // device -- TODO confirm.
  void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
                    const StatusCallback& done);
  // Executes the hierarchical broadcast defined by this op.
  void RunTree();

  // NOTE(review): the order of the data members below is also their
  // initialization order; keep it in sync with the constructor's
  // initializer list when editing.

  // Status of the operation; presumably what is reported through `done_` --
  // confirm.
  Status status_;
  CollectiveExecutor* col_exec_;  // Not owned
  const DeviceMgr* dev_mgr_;      // Not owned
  OpKernelContext* ctx_;          // Not owned
  // Held by reference, not copied: the referenced parameters must outlive
  // this object.
  const CollectiveParams& col_params_;
  const string exec_key_;
  // Both are const, so they are fixed at construction.  Presumably this
  // device's rank and whether it is the broadcast source -- confirm.
  const int rank_;
  const bool is_source_;
  Tensor* output_;  // Not owned
  // Owned by this object (released when the Broadcaster is destroyed).
  std::unique_ptr<CollectiveAdapter> ca_;
  // Completion callback; presumably the one handed to Run() -- confirm.
  StatusCallback done_;
  Device* device_;  // The device for which this instance labors
  DeviceLocality device_locality_;
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_