/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ #include #include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/framework/collective.h" namespace tensorflow { // Hierarchical tree-algorithm implementation of collective broadcast. class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface { public: HierarchicalTreeBroadcaster(); ~HierarchicalTreeBroadcaster() override = default; // Establishes the subdiv permutations needed for a hierarchical broadcast. // If all devices are local, establishes a single subdiv comprising all // devices. If any devices are on a different task, establishes n+1 subdivs // for n tasks. // The first subdiv comprises one device per task which gets the tensor on // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task // i. Status InitializeCollectiveParams(CollectiveParams* col_params) override; // Initializes members of CollectiveContext not yet initialized, i.e. device // and device_locality. Also saves the CollectiveContext in this object. Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; // Begins async execution of the hierarchical tree broadcast. // Must be called in a blockable thread. // TODO(b/80529858): remove the previous warning when we have a dedicated // collective threadpool. void Run(StatusCallback done) override; // Returns the rank of the device from which this device should receive // its value, -1 if no value should be received. static int TreeRecvFrom(const CollectiveParams& cp, int subdiv); // Populates targets with the ranks of the devices to which this device // should forward the value. static void TreeSendTo(const CollectiveParams& cp, int subdiv, std::vector* targets); private: // Get the task to which the device at `device_rank` belongs. int GetDeviceTask(int device_rank, const std::vector& dev_per_task); // Sends `src_tensor` asynchronously from this device to device at `dst_rank` // in `subdiv`. Calls `done` upon completion. void DispatchSend(int subdiv, int dst_rank, int src_rank, const Tensor* src_tensor, const StatusCallback& done); // Receives a tensor into the memory buffer owned by `dst_tensor` at this // device from device at `src_rank` in `subdiv`. Calls `done` upon // completion. void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor, const StatusCallback& done); // Executes the hierarchical broadcast defined by this op. void RunTree(); CollectiveContext* col_ctx_; // Not owned const CollectiveParams* col_params_; // Not owned StatusCallback done_; Status status_; bool is_source_; }; } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_