Diffstat (limited to 'tensorflow/core/common_runtime/broadcaster.cc')
-rw-r--r--  tensorflow/core/common_runtime/broadcaster.cc  247
1 file changed, 151 insertions(+), 96 deletions(-)
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
index 46142d5923..e1c6b21939 100644
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ b/tensorflow/core/common_runtime/broadcaster.cc
@@ -27,13 +27,14 @@ namespace tensorflow {
 namespace {
 // Key to be used for BufRendezvous by Broadcaster.
-string BroadcastBufKey(const string& exec_key, int src_rank, int dst_rank) {
+string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
+                       int dst_rank) {
   if (READABLE_KEYS) {
-    return strings::StrCat("broadcast(", exec_key, "):src(", src_rank, "):dst(",
-                           dst_rank, ")");
+    return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
+                           "):src(", src_rank, "):dst(", dst_rank, ")");
   } else {
     // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
-    return strings::StrCat(exec_key, ":", src_rank, ":", dst_rank);
+    return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
   }
 }
 }  // namespace
@@ -85,11 +86,15 @@ void Broadcaster::Run(StatusCallback done) {
 // device, no send to it is necessary.
 
 /* static*/
-int Broadcaster::TreeRecvFrom(const CollectiveParams& cp) {
-  DCHECK_EQ(1, cp.subdiv_rank.size());
-  if (cp.is_source) return -1;
-  int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
-  int my_rank = cp.subdiv_rank[0];
+int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) {
+  DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+  int my_rank = cp.subdiv_rank[subdiv];
+  if (-1 == my_rank) return -1;
+
+  const auto& impl = cp.instance.impl_details;
+  DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+  int source_rank = impl.subdiv_source_rank[subdiv];
+  if (my_rank == source_rank) return -1;
   if (source_rank == 0) {
     return (my_rank - 1) / 2;
   } else {
@@ -99,13 +104,24 @@ int Broadcaster::TreeRecvFrom(const CollectiveParams& cp) {
 }
 
 /* static */
-void Broadcaster::TreeSendTo(const CollectiveParams& cp,
+void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv,
                              std::vector<int>* targets) {
-  DCHECK_EQ(1, cp.subdiv_rank.size());
+  DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+  int my_rank = cp.subdiv_rank[subdiv];
+  if (-1 == my_rank) return;
+
+  const auto& impl = cp.instance.impl_details;
+  DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+  int source_rank = impl.subdiv_source_rank[subdiv];
+
+  int group_size = 0;
+  for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
+    if (impl.subdiv_permutations[subdiv][i] >= 0) {
+      group_size++;
+    }
+  }
+
   targets->clear();
-  int my_rank = cp.subdiv_rank[0];
-  DCHECK_EQ(1, cp.instance.impl_details.subdiv_source_rank.size());
-  int source_rank = cp.instance.impl_details.subdiv_source_rank[0];
   int successor_rank = 0;
   if (source_rank == 0) {
     successor_rank = (2 * my_rank) + 1;
@@ -116,108 +132,147 @@ void Broadcaster::TreeSendTo(const CollectiveParams& cp,
   if (cp.is_source && source_rank != 0) {
     // The source sends to rank 0,1 in addition to its positional
     // descendants.
-    if (cp.group.group_size > 1) {
+    if (group_size > 1) {
       targets->push_back(0);
     }
-    if (cp.group.group_size > 2 && source_rank != 1) {
+    if (group_size > 2 && source_rank != 1) {
       targets->push_back(1);
     }
   }
   for (int i = 0; i < 2; ++i) {
-    if (successor_rank < cp.group.group_size && successor_rank != source_rank) {
+    if (successor_rank < group_size && successor_rank != source_rank) {
       targets->push_back(successor_rank);
     }
     ++successor_rank;
   }
 }
 
-// Execute a tree broadcast, i.e. each non-source device receives from
-// one other and sends to up-to two others.
+// Executes a hierarchical tree broadcast.
+// Each subdiv is a broadcast between a subset of the devices.
+// If there is only one task, there is one subdiv comprising a broadcast between
+// all devices belonging to the task.
+// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
+// subdiv, one device from each task participates in a binary tree broadcast.
+// Each task receives a copy of the tensor on one device via this broadcast.
+// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
+// corresponds to broadcast between all devices on task i. Thus, each task
+// participates in at most 2 subdivs.
 void Broadcaster::RunTree() {
-  mutex mu;               // also guards status_ while callbacks are pending
-  int pending_count = 0;  // GUARDED_BY(mu)
-  condition_variable all_done;
-  std::vector<int> send_to_ranks;
-  TreeSendTo(col_params_, &send_to_ranks);
-
-  if (!is_source_) {
-    // Begin by receiving the value.
-    int recv_from_rank = TreeRecvFrom(col_params_);
-    Notification note;
-    DispatchRecv(recv_from_rank, output_,
-                 [this, recv_from_rank, &mu, &note](const Status& s) {
-                   mutex_lock l(mu);
-                   status_.Update(s);
-                   note.Notify();
-                 });
-    note.WaitForNotification();
-  }
+  int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
+  // TODO(ayushd): this is easily improved when a node participates in both
+  // first and second subdivision. It would first send to its descendents in
+  // the first subdiv, then wait until all pending ops are finished before
+  // sending to descendents in second subdiv. A better implementation would
+  // collapse the two send blocks.
+  for (int si = 0; si < num_subdivs; si++) {
+    int my_rank = col_params_.subdiv_rank[si];
+    // If rank is -1, this device does not participate in this subdiv.
+    if (-1 == my_rank) continue;
+    int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si];
+    if (VLOG_IS_ON(1)) {
+      string subdiv_buf;
+      for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) {
+        strings::StrAppend(&subdiv_buf, r, ",");
+      }
+      VLOG(1) << "Running Broadcast tree device=" << device_->name()
+              << " subdiv=" << si << " perm=" << subdiv_buf
+              << " my_rank=" << my_rank << " source_rank=" << source_rank;
+    }
+
+    mutex mu;               // also guards status_ while callbacks are pending
+    int pending_count = 0;  // GUARDED_BY(mu)
+    condition_variable all_done;
 
-  // Then forward value to all descendent devices.
-  if (status_.ok()) {
-    for (int i = 0; i < send_to_ranks.size(); ++i) {
-      int target_rank = send_to_ranks[i];
-      {
-        mutex_lock l(mu);
-        ++pending_count;
+    if (my_rank >= 0 && my_rank != source_rank) {
+      // Begin by receiving the value.
+      int recv_from_rank = TreeRecvFrom(col_params_, si);
+      Notification note;
+      DispatchRecv(si, recv_from_rank, my_rank, output_,
+                   [this, &mu, &note](const Status& s) {
+                     mutex_lock l(mu);
+                     status_.Update(s);
+                     note.Notify();
+                   });
+      note.WaitForNotification();
+    }
+
+    // Then forward value to all descendent devices.
+    if (my_rank >= 0 && status_.ok()) {
+      std::vector<int> send_to_ranks;
+      TreeSendTo(col_params_, si, &send_to_ranks);
+      for (int i = 0; i < send_to_ranks.size(); ++i) {
+        int target_rank = send_to_ranks[i];
+        {
+          mutex_lock l(mu);
+          ++pending_count;
+        }
+        DispatchSend(si, target_rank, my_rank,
                     (is_source_ ? &ctx_->input(0) : output_),
+                     [this, &mu, &pending_count, &all_done](const Status& s) {
+                       mutex_lock l(mu);
+                       status_.Update(s);
+                       --pending_count;
+                       if (pending_count == 0) {
+                         all_done.notify_all();
+                       }
+                     });
       }
-      DispatchSend(
-          target_rank, (is_source_ ? &ctx_->input(0) : output_),
-          [this, target_rank, &mu, &pending_count, &all_done](const Status& s) {
-            mutex_lock l(mu);
-            status_.Update(s);
-            --pending_count;
-            if (pending_count == 0) {
-              all_done.notify_all();
-            }
-          });
     }
-  }
 
-  if (status_.ok() && is_source_) {
-    // Meanwhile, copy input to output if we weren't lucky enough to
-    // be able to reuse input as output.
-    const Tensor* input = &ctx_->input(0);
-    if (input != output_ &&
-        (DMAHelper::base(input) != DMAHelper::base(output_))) {
-      {
-        mutex_lock l(mu);
-        ++pending_count;
+    // For the original source device, we copy input to output if they are
+    // different.
+    // If there is only 1 subdiv, we do this in that subdiv. If there is more
+    // than 1 subdiv, then the original source device will participate in 2
+    // subdivs - the global inter-task broadcast and one local intra-task
+    // broadcast. In this case, we perform the copy in the second subdiv for
+    // this device.
+    if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+      VLOG(2) << "copying input to output for device=" << device_->name()
+              << " subdiv=" << si;
+      const Tensor* input = &ctx_->input(0);
+      if (input != output_ &&
+          (DMAHelper::base(input) != DMAHelper::base(output_))) {
+        {
+          mutex_lock l(mu);
+          ++pending_count;
+        }
+        DeviceContext* op_dev_ctx = ctx_->op_device_context();
+        CollectiveRemoteAccessLocal::MemCpyAsync(
+            op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
+            ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/
+            [this, &mu, &pending_count, &all_done](const Status& s) {
+              mutex_lock l(mu);
+              status_.Update(s);
+              --pending_count;
+              if (0 == pending_count) {
+                all_done.notify_all();
+              }
+            });
       }
-      DeviceContext* op_dev_ctx = ctx_->op_device_context();
-      CollectiveRemoteAccessLocal::MemCpyAsync(
-          op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
-          ctx_->output_alloc_attr(0), input, output_, 0 /*steam_index*/,
-          [this, &mu, &pending_count, &all_done](const Status& s) {
-            mutex_lock l(mu);
-            status_.Update(s);
-            --pending_count;
-            if (0 == pending_count) {
-              all_done.notify_all();
-            }
-          });
     }
-  }
 
-  // Then wait for all pending actions to complete.
-  {
-    mutex_lock l(mu);
-    if (pending_count > 0) {
-      all_done.wait(l);
+    // Then wait for all pending actions to complete.
+    {
+      mutex_lock l(mu);
+      if (pending_count > 0) {
+        all_done.wait(l);
+      }
     }
   }
-
-  VLOG(2) << "return status " << status_;
+  VLOG(2) << "device=" << device_->name() << " return status " << status_;
   done_(status_);
 }
 
-void Broadcaster::DispatchSend(int dst_rank, const Tensor* src_tensor,
+void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank,
+                               const Tensor* src_tensor,
                                const StatusCallback& done) {
-  string send_buf_key = BroadcastBufKey(exec_key_, rank_, dst_rank);
-  VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
-          << device_->name();
+  string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
   int dst_idx =
-      col_params_.instance.impl_details.subdiv_permutations[0][dst_rank];
+      col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank];
+  VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
+          << device_->name() << " to_device "
+          << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv
+          << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
   col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
                         col_params_.instance.task_names[dst_idx], send_buf_key,
                         device_, ctx_->op_device_context(),
@@ -225,15 +280,15 @@ void Broadcaster::DispatchSend(int dst_rank, const Tensor* src_tensor,
                         device_locality_, done);
 }
 
-void Broadcaster::DispatchRecv(int src_rank, Tensor* dst_tensor,
-                               const StatusCallback& done) {
-  string recv_buf_key = BroadcastBufKey(exec_key_, src_rank, rank_);
+void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank,
+                               Tensor* dst_tensor, const StatusCallback& done) {
+  string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
   int src_idx =
-      col_params_.instance.impl_details.subdiv_permutations[0][src_rank];
+      col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank];
   VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
-          << col_params_.instance.device_names[src_idx];
-  int dst_idx = col_params_.instance.impl_details.subdiv_permutations[0][rank_];
-  CHECK_EQ(col_params_.instance.device_names[dst_idx], device_->name());
+          << col_params_.instance.device_names[src_idx] << " to_device "
+          << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank
+          << " src_idx=" << src_idx;
   col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
                           col_params_.instance.task_names[src_idx],
                           col_params_.task.is_local[src_idx], recv_buf_key,
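The BroadcastBufKey change threads the subdivision index through the rendezvous key in addition to the source and destination ranks; without it, a device that participates in two subdivs of the same broadcast instance could post and receive buffers under colliding keys. A minimal standalone sketch of the READABLE_KEYS format, with std::string concatenation standing in for TensorFlow's strings::StrCat, and "exec_1" as a made-up exec key:

// Standalone sketch of the READABLE_KEYS key format; std::to_string and
// operator+ stand in for TensorFlow's strings::StrCat.
#include <iostream>
#include <string>

std::string BroadcastBufKeySketch(const std::string& exec_key, int subdiv,
                                  int src_rank, int dst_rank) {
  return "broadcast(" + exec_key + "):subdiv(" + std::to_string(subdiv) +
         "):src(" + std::to_string(src_rank) + "):dst(" +
         std::to_string(dst_rank) + ")";
}

int main() {
  // The same src/dst pair in two different subdivs now yields distinct
  // rendezvous keys, which is the point of adding subdiv to the key.
  std::cout << BroadcastBufKeySketch("exec_1", 0, 0, 1) << "\n";
  std::cout << BroadcastBufKeySketch("exec_1", 1, 0, 1) << "\n";
  return 0;
}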
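TreeRecvFrom and TreeSendTo keep the implicit binary-heap layout: when a subdiv's source rank is 0, rank r receives from (r - 1) / 2 and sends to 2r + 1 and 2r + 2, bounded by group_size, which is now counted from the non-negative entries of the subdiv permutation rather than taken from cp.group.group_size. A self-contained sketch of just that source_rank == 0 case (the splice logic for a nonzero source rank, handled by the else branches, is omitted here):

// Standalone sketch of the subdiv tree topology for source_rank == 0;
// mirrors the visible logic of TreeRecvFrom/TreeSendTo but is not
// TensorFlow code.
#include <cstdio>
#include <vector>

int TreeParent(int my_rank) {
  return (my_rank - 1) / 2;  // as in TreeRecvFrom when source_rank == 0
}

std::vector<int> TreeChildren(int my_rank, int group_size) {
  std::vector<int> targets;
  int successor_rank = (2 * my_rank) + 1;  // as in TreeSendTo
  for (int i = 0; i < 2; ++i) {
    if (successor_rank < group_size) targets.push_back(successor_rank);
    ++successor_rank;
  }
  return targets;
}

int main() {
  const int group_size = 6;
  for (int r = 0; r < group_size; ++r) {
    std::printf("rank %d recv_from=%d send_to=[", r,
                r == 0 ? -1 : TreeParent(r));
    for (int t : TreeChildren(r, group_size)) std::printf(" %d", t);
    std::printf(" ]\n");
  }
  return 0;
}

Printing the table for six ranks shows every non-source rank has exactly one parent and each rank sends to at most two others, which is the invariant the original comment ("each non-source device receives from one other and sends to up-to two others") described.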
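The subdiv accounting in the new RunTree() header comment can be checked with a toy model. Picking device 0 as each task's representative in the global subdiv is an illustrative assumption only; the real assignment comes out of collective parameter resolution, which is outside this diff:

// Toy model of the subdiv layout described in the RunTree() comment:
// with n > 1 tasks there are n + 1 subdivs; subdiv 0 is the inter-task
// broadcast with one representative device per task, and subdiv i + 1
// is the intra-task broadcast for task i.
#include <cstdio>

int main() {
  const int num_tasks = 3;
  const int devices_per_task = 4;
  const int num_subdivs = (num_tasks == 1) ? 1 : num_tasks + 1;
  std::printf("num_subdivs=%d\n", num_subdivs);
  for (int t = 0; t < num_tasks; ++t) {
    for (int d = 0; d < devices_per_task; ++d) {
      // Every device joins its task's subdiv (t + 1); the representative
      // additionally joins the global subdiv 0.
      const int joined = 1 + ((num_tasks > 1 && d == 0) ? 1 : 0);
      std::printf("task=%d device=%d subdivs_joined=%d\n", t, d, joined);
    }
  }
  // No device joins more than 2 subdivs, matching the comment's claim;
  // a device's rank is -1 in every subdiv it does not join, which is why
  // RunTree() can simply `continue` past those iterations.
  return 0;
}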
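Each subdiv iteration in the new RunTree() uses the same completion-counting idiom: every asynchronous dispatch increments pending_count under the mutex, every callback decrements it, and the loop blocks on the condition variable until the count drains before moving to the next subdiv. A standalone sketch of the idiom, with std::mutex and std::condition_variable standing in for TensorFlow's mutex and condition_variable aliases and plain threads standing in for the dispatch callbacks:

// Sketch of the per-subdiv completion counting used by RunTree().
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  std::mutex mu;  // also guards the shared status in the real code
  std::condition_variable all_done;
  int pending_count = 0;
  std::vector<std::thread> callbacks;

  // Dispatch four stand-ins for DispatchSend; each bumps the count up
  // front and decrements it from its completion callback.
  for (int i = 0; i < 4; ++i) {
    {
      std::lock_guard<std::mutex> l(mu);
      ++pending_count;
    }
    callbacks.emplace_back([&mu, &all_done, &pending_count] {
      std::lock_guard<std::mutex> l(mu);
      --pending_count;
      if (pending_count == 0) all_done.notify_all();
    });
  }

  // Mirror of the wait block at the end of each subdiv iteration.
  {
    std::unique_lock<std::mutex> l(mu);
    all_done.wait(l, [&pending_count] { return pending_count == 0; });
  }
  for (auto& t : callbacks) t.join();
  return 0;
}

The predicate form of wait plays the same role as the `if (pending_count > 0)` guard in the diff: it prevents sleeping forever when every callback has already fired by the time the dispatching thread takes the lock.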