17 files changed, 1427 insertions, 1069 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 44662ea79e..51225f34bc 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2707,12 +2707,13 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", - "common_runtime/broadcaster.h", + "common_runtime/hierarchical_tree_broadcaster.h", "common_runtime/buf_rendezvous.h", "common_runtime/build_graph_options.h", "common_runtime/collective_executor_mgr.h", "common_runtime/collective_param_resolver_local.h", "common_runtime/collective_rma_local.h", + "common_runtime/collective_util.h", "common_runtime/constant_folding.h", "common_runtime/copy_tensor.h", "common_runtime/costmodel_manager.h", @@ -2758,12 +2759,12 @@ tf_cuda_library( "common_runtime/allocator_retry.cc", "common_runtime/base_collective_executor.cc", "common_runtime/bfc_allocator.cc", - "common_runtime/broadcaster.cc", "common_runtime/buf_rendezvous.cc", "common_runtime/build_graph_options.cc", "common_runtime/collective_executor_mgr.cc", "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", + "common_runtime/collective_util.cc", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -2778,6 +2779,7 @@ tf_cuda_library( "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", + "common_runtime/hierarchical_tree_broadcaster.cc", "common_runtime/local_device.cc", "common_runtime/lower_if_op.cc", "common_runtime/memory_types.cc", @@ -3664,10 +3666,10 @@ tf_cc_tests_gpu( ) tf_cc_tests_gpu( - name = "broadcaster_test", + name = "hierarchical_tree_broadcaster_test", size = "small", srcs = [ - "common_runtime/broadcaster_test.cc", + "common_runtime/hierarchical_tree_broadcaster_test.cc", ], linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags(), diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 425a628a49..5b01f7fa03 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -14,13 +14,28 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/base_collective_executor.h" -#include "tensorflow/core/common_runtime/broadcaster.h" +#include <algorithm> +#include <functional> +#include <utility> + #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/ring_reducer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" #define VALUE_IN_DEBUG_STRING false @@ -211,104 +226,67 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, }; Tensor* output = ctx->mutable_output(0); - string error; - switch (col_params.instance.type) { - case REDUCTION_COLLECTIVE: { - // TODO(tucker): support other reduction algorithms, - // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc. - const Tensor* input = &ctx->input(0); - RingReducer* reducer = - CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_, - input, output, &error); - if (!reducer) { - done_safe(errors::Internal(error)); - return; - } - // Run in an I/O thread, so as not to starve the executor threads. - // TODO(tucker): Instead of forking every per-device Collective - // Op off into its own thread, consider queuing them on a - // fixed-size thread-pool dedicated to running CollectiveOps. - SchedClosure([reducer, done_safe]() { - reducer->Run([reducer, done_safe](const Status& s) { - done_safe(s); - delete reducer; - }); - }); - } break; - - case BROADCAST_COLLECTIVE: { - Broadcaster* broadcaster = CreateBroadcaster( - ctx, CtxParams(ctx), col_params, exec_key, step_id_, output, &error); - if (!broadcaster) { - done_safe(errors::Internal(error)); - return; - } - // Run in an I/O thread, so as not to starve the executor threads. - SchedClosure([broadcaster, done_safe]() { - broadcaster->Run([broadcaster, done_safe](const Status& s) { - done_safe(s); - delete broadcaster; - }); - }); - } break; - - default: - done_safe(errors::Internal("Unimplemented CollectiveType ", - col_params.instance.type)); + const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE || + (col_params.instance.type == BROADCAST_COLLECTIVE && + col_params.is_source)) + ? 
&ctx->input(0) + : nullptr; + CollectiveImplementationInterface* col_impl = nullptr; + Status status = CreateCollective(col_params, &col_impl); + if (!status.ok()) { + done_safe(status); + DCHECK_EQ(nullptr, col_impl); + return; } -} - -RingReducer* BaseCollectiveExecutor::CreateReducer( - OpKernelContext* ctx, OpKernelContext::Params* params, - const CollectiveParams& col_params, const string& exec_key, int64 step_id, - const Tensor* input, Tensor* output, string* error) { - switch (col_params.instance.data_type) { - case DT_INT32: - if (col_params.group.device_type == DEVICE_GPU) { - *error = - "Collective Reduce does not support datatype DT_INT32 on " - "DEVICE_GPU"; - return nullptr; - } - TF_FALLTHROUGH_INTENDED; - case DT_FLOAT: - case DT_DOUBLE: - case DT_INT64: - return new RingReducer(this, dev_mgr_, ctx, params, col_params, exec_key, - step_id, input, output); - break; - default: - *error = strings::StrCat("Collective Reduce does not support datatype ", - col_params.instance.data_type); - return nullptr; + CollectiveContext* col_ctx = + new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params, + exec_key, step_id_, input, output); + status = col_impl->InitializeCollectiveContext(col_ctx); + if (!status.ok()) { + done_safe(status); + delete col_ctx; + delete col_impl; + return; } + // Run in an I/O thread, so as not to starve the executor threads. + // TODO(b/80529858): Instead of forking every per-device Collective + // Op off into its own thread, consider queuing them on a + // fixed-size thread-pool dedicated to running CollectiveOps. + SchedClosure([col_impl, col_ctx, done_safe]() { + col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) { + done_safe(s); + delete col_ctx; + delete col_impl; + }); + }); } -Broadcaster* BaseCollectiveExecutor::CreateBroadcaster( - OpKernelContext* ctx, OpKernelContext::Params* params, - const CollectiveParams& col_params, const string& exec_key, int64 step_id, - Tensor* output, string* error) { +Status BaseCollectiveExecutor::CreateCollective( + const CollectiveParams& col_params, + CollectiveImplementationInterface** col_impl) { + *col_impl = nullptr; + Status status; switch (col_params.instance.data_type) { case DT_INT32: if (col_params.group.device_type == DEVICE_GPU) { - *error = - "Collective Broadcast does not support datatype DT_INT32 on " - "DEVICE_GPU"; - return nullptr; + status = errors::Internal( + "CollectiveImplementation does not support datatype DT_INT32 on " + "DEVICE_GPU"); } TF_FALLTHROUGH_INTENDED; case DT_FLOAT: case DT_DOUBLE: case DT_INT64: { - return new Broadcaster(this, dev_mgr_, ctx, params, col_params, exec_key, - step_id, output); - } break; + status = CollectiveRegistry::Lookup( + col_params.instance.impl_details.collective_name, col_impl); + break; + } default: - *error = - strings::StrCat("Collective Broadcast does not support datatype ", - DataTypeString(col_params.instance.data_type)); - return nullptr; + status = errors::Internal( + "CollectiveImplementation does not support datatype ", + col_params.instance.data_type); } + return status; } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 3af9286264..360ce4db7b 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -15,15 +15,17 @@ limitations under the License. 
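
The heart of this change is visible above: ExecuteAsync no longer switches on the collective type. It asks CreateCollective for an implementation, initializes a CollectiveContext, and schedules the run on an I/O thread. Below is a minimal sketch of that ownership and error-handling flow using simplified stand-in types; Status, Collective, and SchedClosure here are toy versions for illustration, not the TensorFlow classes.

    #include <chrono>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <thread>

    struct Status {
      bool ok_ = true;
      std::string msg_;
      bool ok() const { return ok_; }
    };

    struct Collective {
      Status Initialize() { return Status(); }
      void Run(std::function<void(const Status&)> done) { done(Status()); }
    };

    // Run the closure on its own thread so executor threads are not
    // starved, mirroring SchedClosure in the patch.
    void SchedClosure(std::function<void()> fn) {
      std::thread(std::move(fn)).detach();
    }

    void ExecuteAsync(std::function<void(const Status&)> done) {
      Collective* col = new Collective;  // CreateCollective step
      Status s = col->Initialize();      // InitializeCollectiveContext step
      if (!s.ok()) {                     // init failure: free before reporting
        delete col;
        done(s);
        return;
      }
      SchedClosure([col, done]() {
        col->Run([col, done](const Status& s) {
          done(s);                       // report completion first...
          delete col;                    // ...then the callback owns cleanup
        });
      });
    }

    int main() {
      ExecuteAsync([](const Status& s) {
        std::cout << "collective done, ok=" << s.ok() << "\n";
      });
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }

The key invariants match the patch: a failed initialization frees everything before reporting, and on success the run callback is the last owner of the objects it deletes.
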
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ +#include <memory> #include <string> + #include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { -class Broadcaster; +class CollectiveImplementation; class DeviceMgr; -class RingReducer; +class Device; // Helper interface that aliases regular subfields of a Tensor as separate // Tensors for in-place update. @@ -133,18 +135,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor { std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_; private: - RingReducer* CreateReducer(OpKernelContext* ctx, - OpKernelContext::Params* params, - const CollectiveParams& col_params, - const string& exec_key, int64 step_id, - const Tensor* input, Tensor* output, - string* error); - - Broadcaster* CreateBroadcaster(OpKernelContext* ctx, - OpKernelContext::Params* params, - const CollectiveParams& col_params, - const string& exec_key, int64 step_id, - Tensor* output, string* error); + Status CreateCollective(const CollectiveParams& col_params, + CollectiveImplementationInterface** col_impl); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc deleted file mode 100644 index e1c6b21939..0000000000 --- a/tensorflow/core/common_runtime/broadcaster.cc +++ /dev/null @@ -1,300 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/broadcaster.h" - -#include "tensorflow/core/common_runtime/collective_rma_local.h" -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/platform/env.h" - -// Set true for greater intelligibility of debug mode log messages. -#define READABLE_KEYS false - -namespace tensorflow { - -namespace { -// Key to be used for BufRendezvous by Broadcaster. -string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank, - int dst_rank) { - if (READABLE_KEYS) { - return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv, - "):src(", src_rank, "):dst(", dst_rank, ")"); - } else { - // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash. 
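
The BroadcastBufKey helper being deleted here (and re-created nearly verbatim in hierarchical_tree_broadcaster.cc later in this diff) builds the rendezvous key that sender and receiver must compute identically for a transfer to match up. A self-contained sketch of the two formats, with strings::StrCat replaced by plain std::string concatenation:

    #include <iostream>
    #include <string>

    // Mirrors the two layouts in BroadcastBufKey. Sender and receiver
    // must derive the identical key for the BufRendezvous to match.
    std::string BroadcastBufKey(const std::string& exec_key, int subdiv,
                                int src_rank, int dst_rank, bool readable) {
      if (readable) {
        return "broadcast(" + exec_key + "):subdiv(" + std::to_string(subdiv) +
               "):src(" + std::to_string(src_rank) + "):dst(" +
               std::to_string(dst_rank) + ")";
      }
      return exec_key + ":" + std::to_string(subdiv) + ":" +
             std::to_string(src_rank) + ":" + std::to_string(dst_rank);
    }

    int main() {
      std::cout << BroadcastBufKey("key7", 1, 0, 3, true) << '\n';
      std::cout << BroadcastBufKey("key7", 1, 0, 3, false) << '\n';  // key7:1:0:3
    }
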
- return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank); - } -} -} // namespace - -Broadcaster::Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, - OpKernelContext* ctx, OpKernelContext::Params* params, - const CollectiveParams& col_params, - const string& exec_key, int64 step_id, Tensor* output) - : col_exec_(col_exec), - dev_mgr_(dev_mgr), - ctx_(ctx), - col_params_(col_params), - exec_key_(exec_key), - rank_(col_params.subdiv_rank[0]), - is_source_(col_params.is_source), - output_(output), - done_(nullptr), - device_(nullptr) {} - -void Broadcaster::Run(StatusCallback done) { - // The optimal data transfer choreography is going to very platform dependent. - // That will be addressed by later improvements here or by platform-specific - // overrides of collective broadcast. The initial version is simply - // a binary tree that completely ignores DeviceLocality. - done_ = std::move(done); - - // Get the device for which we're executing and look up its locality. - status_ = dev_mgr_->LookupDevice( - col_params_.instance.device_names[col_params_.default_rank], &device_); - if (!status_.ok()) { - done_(status_); - return; - } - CHECK(device_); - device_locality_ = device_->attributes().locality(); - - RunTree(); -} - -// Binary tree parent/child relations are trivial to calculate, i.e. -// device at rank r is the parent of 2r+1 and 2r+2. The one exception -// is if the source is not rank 0. We treat that case as though the -// source is appended to the front of the rank ordering as well as -// continuing to occupy its current position. Hence we calculate as -// though each device's rank is actually r+1, then subtract 1 again to -// get the descendent ranks. If the source is not rank 0 then its -// descendants include both {0,1} and the descendents of its current -// position. Where a non-0-rank source is a descendent of another -// device, no send to it is necessary. - -/* static*/ -int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) { - DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); - int my_rank = cp.subdiv_rank[subdiv]; - if (-1 == my_rank) return -1; - - const auto& impl = cp.instance.impl_details; - DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); - int source_rank = impl.subdiv_source_rank[subdiv]; - if (my_rank == source_rank) return -1; - if (source_rank == 0) { - return (my_rank - 1) / 2; - } else { - int predecessor_rank = (my_rank / 2) - 1; - return (predecessor_rank < 0) ? source_rank : predecessor_rank; - } -} - -/* static */ -void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv, - std::vector<int>* targets) { - DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); - int my_rank = cp.subdiv_rank[subdiv]; - if (-1 == my_rank) return; - - const auto& impl = cp.instance.impl_details; - DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); - int source_rank = impl.subdiv_source_rank[subdiv]; - - int group_size = 0; - for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) { - if (impl.subdiv_permutations[subdiv][i] >= 0) { - group_size++; - } - } - - targets->clear(); - int successor_rank = 0; - if (source_rank == 0) { - successor_rank = (2 * my_rank) + 1; - } else { - successor_rank = (2 * (my_rank + 1)); - } - DCHECK_NE(successor_rank, my_rank); - if (cp.is_source && source_rank != 0) { - // The source sends to rank 0,1 in addition to its positional - // descendants. 
- if (group_size > 1) { - targets->push_back(0); - } - if (group_size > 2 && source_rank != 1) { - targets->push_back(1); - } - } - for (int i = 0; i < 2; ++i) { - if (successor_rank < group_size && successor_rank != source_rank) { - targets->push_back(successor_rank); - } - ++successor_rank; - } -} - -// Executes a hierarchical tree broadcast. -// Each subdiv is a broadcast between a subset of the devices. -// If there is only one task, there is one subdiv comprising a broadcast between -// all devices belonging to the task. -// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global) -// subdiv, one device from each task participates in a binary tree broadcast. -// Each task receives a copy of the tensor on one device via this broadcast. -// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1 -// corresponds to broadcast between all devices on task i. Thus, each task -// participates in at most 2 subdivs. -void Broadcaster::RunTree() { - int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); - // TODO(ayushd): this is easily improved when a node participates in both - // first and second subdivision. It would first send to its descendents in - // the first subdiv, then wait until all pending ops are finished before - // sending to descendents in second subdiv. A better implementation would - // collapse the two send blocks. - for (int si = 0; si < num_subdivs; si++) { - int my_rank = col_params_.subdiv_rank[si]; - // If rank is -1, this device does not participate in this subdiv. - if (-1 == my_rank) continue; - int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si]; - if (VLOG_IS_ON(1)) { - string subdiv_buf; - for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) { - strings::StrAppend(&subdiv_buf, r, ","); - } - VLOG(1) << "Running Broadcast tree device=" << device_->name() - << " subdiv=" << si << " perm=" << subdiv_buf - << " my_rank=" << my_rank << " source_rank=" << source_rank; - } - - mutex mu; // also guards status_ while callbacks are pending - int pending_count = 0; // GUARDED_BY(mu) - condition_variable all_done; - - if (my_rank >= 0 && my_rank != source_rank) { - // Begin by receiving the value. - int recv_from_rank = TreeRecvFrom(col_params_, si); - Notification note; - DispatchRecv(si, recv_from_rank, my_rank, output_, - [this, &mu, ¬e](const Status& s) { - mutex_lock l(mu); - status_.Update(s); - note.Notify(); - }); - note.WaitForNotification(); - } - - // Then forward value to all descendent devices. - if (my_rank >= 0 && status_.ok()) { - std::vector<int> send_to_ranks; - TreeSendTo(col_params_, si, &send_to_ranks); - for (int i = 0; i < send_to_ranks.size(); ++i) { - int target_rank = send_to_ranks[i]; - { - mutex_lock l(mu); - ++pending_count; - } - DispatchSend(si, target_rank, my_rank, - (is_source_ ? &ctx_->input(0) : output_), - [this, &mu, &pending_count, &all_done](const Status& s) { - mutex_lock l(mu); - status_.Update(s); - --pending_count; - if (pending_count == 0) { - all_done.notify_all(); - } - }); - } - } - - // For the original source device, we copy input to output if they are - // different. - // If there is only 1 subdiv, we do this in that subdiv. If there is more - // than 1 subdiv, then the original source device will participate in 2 - // subdivs - the global inter-task broadcast and one local intra-task - // broadcast. In this case, we perform the copy in the second subdiv for - // this device. 
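
The parent/child arithmetic described in the comments above is compact enough to check by hand. The sketch below re-implements just that arithmetic outside of CollectiveParams and prints the tree for an 8-device subdiv with source rank 0 and with source rank 2, so the shifted-rank case is exercised too:

    #include <iostream>
    #include <vector>

    // Re-derivation of the TreeRecvFrom/TreeSendTo arithmetic. With
    // source rank 0, rank r receives from (r-1)/2 and sends to 2r+1 and
    // 2r+2. A non-zero source acts as if prepended to the rank order,
    // shifting every rank by one.
    int RecvFrom(int rank, int source) {
      if (rank == source) return -1;
      if (source == 0) return (rank - 1) / 2;
      int predecessor = (rank / 2) - 1;
      return predecessor < 0 ? source : predecessor;
    }

    std::vector<int> SendTo(int rank, int source, int group_size) {
      std::vector<int> targets;
      int successor = (source == 0) ? 2 * rank + 1 : 2 * (rank + 1);
      if (rank == source && source != 0) {
        // A non-0-rank source also feeds ranks 0 and 1 directly.
        if (group_size > 1) targets.push_back(0);
        if (group_size > 2 && source != 1) targets.push_back(1);
      }
      for (int i = 0; i < 2; ++i, ++successor)
        if (successor < group_size && successor != source)
          targets.push_back(successor);
      return targets;
    }

    int main() {
      for (int source : {0, 2})
        for (int rank = 0; rank < 8; ++rank) {
          std::cout << "src=" << source << " rank=" << rank
                    << " recv_from=" << RecvFrom(rank, source) << " send_to=";
          for (int t : SendTo(rank, source, 8)) std::cout << t << ' ';
          std::cout << '\n';
        }
    }

For source 2 this prints rank 2 sending to 0, 1, 6 and 7, while ranks 3 through 5 hang off ranks 0 and 1 as usual.
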
- if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) { - VLOG(2) << "copying input to output for device=" << device_->name() - << " subdiv=" << si; - const Tensor* input = &ctx_->input(0); - if (input != output_ && - (DMAHelper::base(input) != DMAHelper::base(output_))) { - { - mutex_lock l(mu); - ++pending_count; - } - DeviceContext* op_dev_ctx = ctx_->op_device_context(); - CollectiveRemoteAccessLocal::MemCpyAsync( - op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0), - ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/ - [this, &mu, &pending_count, &all_done](const Status& s) { - mutex_lock l(mu); - status_.Update(s); - --pending_count; - if (0 == pending_count) { - all_done.notify_all(); - } - }); - } - } - - // Then wait for all pending actions to complete. - { - mutex_lock l(mu); - if (pending_count > 0) { - all_done.wait(l); - } - } - } - VLOG(2) << "device=" << device_->name() << " return status " << status_; - done_(status_); -} - -void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank, - const Tensor* src_tensor, - const StatusCallback& done) { - string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank); - int dst_idx = - col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank]; - VLOG(1) << "DispatchSend " << send_buf_key << " from_device " - << device_->name() << " to_device " - << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv - << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx; - col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx], - col_params_.instance.task_names[dst_idx], send_buf_key, - device_, ctx_->op_device_context(), - ctx_->output_alloc_attr(0), src_tensor, - device_locality_, done); -} - -void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank, - Tensor* dst_tensor, const StatusCallback& done) { - string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank); - int src_idx = - col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank]; - VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device " - << col_params_.instance.device_names[src_idx] << " to_device " - << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank - << " src_idx=" << src_idx; - col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx], - col_params_.instance.task_names[src_idx], - col_params_.task.is_local[src_idx], recv_buf_key, - device_, ctx_->op_device_context(), - ctx_->output_alloc_attr(0), dst_tensor, - device_locality_, 0 /*stream_index*/, done); -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 2a14493a67..52eedae9b7 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -14,7 +14,20 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include <stddef.h> +#include <algorithm> +#include <unordered_map> +#include <utility> + #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -319,206 +332,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) { } } // namespace -int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) { - int num_tasks = static_cast<int>(dev_per_task.size()); - int task_lo = 0; - int task_hi; - for (int ti = 0; ti < num_tasks; ti++) { - task_hi = task_lo + dev_per_task[ti]; - if (task_lo <= device_rank && device_rank < task_hi) return ti; - task_lo += dev_per_task[ti]; - } - LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi - << " devices"; - return -1; -} - -void CollectiveParamResolverLocal::GenerateBcastSubdivPerms( - const string& device, int source_rank, const std::vector<int>& dev_per_task, - CollectiveParams* cp) { - if (VLOG_IS_ON(1)) { - string dpt_buf; - for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";"); - VLOG(1) << "GenerateBcastSubdivPerms device=" << device - << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf; - } - int num_tasks = cp->group.num_tasks; - // If there is just 1 task, then execute binary tree broadcast over all - // devices. Otherwise, the first subdiv is inter-task broadcast, and then - // there are N more subdivs, where N is #task. - int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0); - int total_num_devices = 0; - for (int num_dev : dev_per_task) total_num_devices += num_dev; - - cp->instance.impl_details.subdiv_permutations.resize(num_subdivs); - cp->subdiv_rank.reserve(num_subdivs); - cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs); - - // Inter-task subdiv. Pick one device from each task - this is the source - // device if it belongs to that task, or device 0 for that task. If a device - // does not participate in the subdiv, set subdiv_rank to -1. - if (num_tasks > 1) { - const int sdi = 0; - std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - int device_count = 0; - int source_task = GetDeviceTask(source_rank, dev_per_task); - for (int ti = 0; ti < cp->group.num_tasks; ti++) { - bool participate = false; - if (source_task == ti) { - // Source device belongs to this task. - perm.push_back(source_rank); - participate = cp->instance.device_names[source_rank] == device; - } else { - // Source does not belong to this task, choose dev 0. - perm.push_back(device_count); - participate = cp->instance.device_names[device_count] == device; - } - if (participate) cp->subdiv_rank.push_back(ti); - device_count += dev_per_task[ti]; - } - if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1); - cp->instance.impl_details.subdiv_source_rank.push_back(source_task); - } - - // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set - // source to dev 0 for that task if it does not contain original source, else - // set to rank of original source. 
If a device does not participate in the - // subdiv, set subdiv_rank to -1; - int abs_di = 0; - for (int ti = 0; ti < cp->group.num_tasks; ti++) { - const int sdi = ti + (num_tasks > 1 ? 1 : 0); - std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - bool participate = false; - int subdiv_source = 0; - for (int di = 0; di < dev_per_task[ti]; di++) { - perm.push_back(abs_di); - if (cp->instance.device_names[abs_di] == device) { - participate = true; - cp->subdiv_rank.push_back(di); - } - if (abs_di == source_rank) subdiv_source = di; - abs_di++; - } - if (!participate) cp->subdiv_rank.push_back(-1); - cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source); - } - - for (int sri = 0; sri < num_subdivs; sri++) { - CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0); - } -} - -// Establish the requested number of subdivision permutations based on the -// ring order implicit in the device order. -/*static*/ -void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device, - int source_rank, - CollectiveParams* cp) { - // Each subdiv permutation is a ring formed by rotating each - // single-task subsequence of devices by an offset. This makes most - // sense when each task has the same number of devices but we can't - // depend on that being the case so we'll compute something that - // works in any case. - - // Start by counting the devices in each task. - // Precondition: device_names must be sorted so that all devices in - // the same task are adjacent. - VLOG(2) << "Sorted task names: " - << str_util::Join(cp->instance.task_names, ", "); - std::vector<int> dev_per_task; - const string* prior_task_name = &cp->instance.task_names[0]; - int dev_count = 1; - for (int di = 1; di < cp->group.group_size; ++di) { - if (cp->instance.task_names[di] != *prior_task_name) { - dev_per_task.push_back(dev_count); - dev_count = 1; - prior_task_name = &cp->instance.task_names[di]; - } else { - ++dev_count; - } - } - dev_per_task.push_back(dev_count); - CHECK_EQ(cp->group.num_tasks, dev_per_task.size()); - - CHECK(cp->instance.type == REDUCTION_COLLECTIVE || - cp->instance.type == BROADCAST_COLLECTIVE); - if (cp->instance.type == REDUCTION_COLLECTIVE) { - // Generate a ring permutation for each requested offset. - CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0); - VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations " - << &cp->instance.impl_details.subdiv_permutations; - cp->instance.impl_details.subdiv_permutations.resize( - cp->instance.impl_details.subdiv_offsets.size()); - cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1); - for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size(); - ++sdi) { - std::vector<int>& perm = - cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - int offset = cp->instance.impl_details.subdiv_offsets[sdi]; - // A negative subdivision offset is interpreted as follows: - // 1. Reverse the local device ordering. - // 2. Begin the subdivision at abs(offset) in the reversed ordering. - bool reverse = false; - if (offset < 0) { - offset = abs(offset); - reverse = true; - } - int prior_dev_count = 0; // sum over prior worker device counts - for (int ti = 0; ti < cp->group.num_tasks; ++ti) { - for (int di = 0; di < dev_per_task[ti]; ++di) { - int di_offset = (di + offset) % dev_per_task[ti]; - int offset_di = - reverse ? 
(dev_per_task[ti] - (di_offset + 1)) : di_offset; - // Device index in global subdivision permutation. - int permuted_di = prior_dev_count + offset_di; - int rank = static_cast<int>(perm.size()); - perm.push_back(permuted_di); - if (cp->instance.device_names[permuted_di] == device) { - CHECK_EQ(permuted_di, cp->default_rank); - cp->subdiv_rank[sdi] = rank; - } - } - prior_dev_count += dev_per_task[ti]; - } - CHECK_EQ(cp->group.group_size, perm.size()); - } - } else if (cp->instance.type == BROADCAST_COLLECTIVE) { - GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp); - } - - if (VLOG_IS_ON(1)) { - // Log the computed ring order for each subdiv. - string buf; - for (int sdi = 0; - sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) { - buf = strings::StrCat("Subdiv ", sdi, " device order:\n"); - for (int di = 0; - di < cp->instance.impl_details.subdiv_permutations[sdi].size(); - ++di) { - int idx = cp->instance.impl_details.subdiv_permutations[sdi][di]; - if (idx >= 0) { - CHECK_GT(cp->instance.device_names.size(), idx); - strings::StrAppend(&buf, cp->instance.device_names[idx], "\n"); - } - } - strings::StrAppend(&buf, " subdiv_offsets: "); - for (auto o : cp->instance.impl_details.subdiv_offsets) - strings::StrAppend(&buf, o, " "); - strings::StrAppend(&buf, " SubdivRank: "); - for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " "); - if (cp->instance.type == BROADCAST_COLLECTIVE) { - strings::StrAppend(&buf, " subdiv_source_rank: "); - for (auto src : cp->instance.impl_details.subdiv_source_rank) - strings::StrAppend(&buf, src, " "); - } - VLOG(1) << buf; - } - } -} - void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp) { cp->task.is_local.resize(cp->group.group_size, false); @@ -785,29 +598,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( // Populate the fields common across task, also default_rank. SetDefaultRank(device, cp); CompleteTaskIsLocal(task_name_, cp); + // TODO(b/113171733): we need a better way to pick the collective + // implementation. The ideal way would depend upon the topology and link + // strength before picking a particular implementation. + cp->instance.impl_details.collective_name = + (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast" + : "RingReduce"; + CollectiveImplementationInterface* col_impl; + Status lookup_status = CollectiveRegistry::LookupParamResolverInstance( + cp->instance.impl_details.collective_name, &col_impl); + if (!lookup_status.ok()) { + done(lookup_status); + return; + } // If broadcast, may need to wait for source discovery. 
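
With the subdiv-generation logic moved out of the resolver, implementation choice reduces to a string key ("HierarchicalTreeBroadcast" or "RingReduce") resolved through CollectiveRegistry. Here is a toy version of that name-keyed factory lookup; the real registry lives in tensorflow/core/framework and differs in detail, so the map and Lookup signature below are illustrative assumptions only.

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    struct CollectiveImpl {
      virtual ~CollectiveImpl() {}
      virtual std::string Name() const = 0;
    };

    // Toy name-keyed factory table.
    std::map<std::string, std::function<CollectiveImpl*()>>& Registry() {
      static std::map<std::string, std::function<CollectiveImpl*()>> table;
      return table;
    }

    bool Lookup(const std::string& name, CollectiveImpl** out) {
      auto it = Registry().find(name);
      if (it == Registry().end()) return false;  // errors::Internal in TF
      *out = it->second();
      return true;
    }

    struct TreeBcast : CollectiveImpl {
      std::string Name() const override { return "HierarchicalTreeBroadcast"; }
    };

    int main() {
      Registry()["HierarchicalTreeBroadcast"] = [] { return new TreeBcast; };
      // Mirrors the resolver's choice: broadcast -> tree, reduction -> ring.
      bool is_broadcast = true;
      std::string name =
          is_broadcast ? "HierarchicalTreeBroadcast" : "RingReduce";
      CollectiveImpl* impl = nullptr;
      if (Lookup(name, &impl)) std::cout << "resolved " << impl->Name() << '\n';
      delete impl;
    }

As the TODO(b/113171733) above notes, the selection itself is still a hard-coded mapping from collective type to name; the registry only decouples who implements that name.
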
if (cp->instance.type == BROADCAST_COLLECTIVE) { CompleteInstanceSource(ir, cp, is_source, - [this, ir, device, cp, done](InstanceRec* irec) { + [col_impl, ir, device, cp, done](InstanceRec* irec) { CHECK_EQ(ir, irec); Status s; - int source_rank; { mutex_lock l(irec->out_mu); irec->WaitForOutMu(l); s = irec->status; - source_rank = irec->source_rank; + cp->source_rank = irec->source_rank; } if (s.ok()) { - GenerateSubdivPerms(device, source_rank, cp); + s = col_impl->InitializeCollectiveParams(cp); } done(s); }); - return; } else { - GenerateSubdivPerms(device, 0, cp); + done(col_impl->InitializeCollectiveParams(cp)); } - done(Status::OK()); } void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir, diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 9372fd6272..c5c3497e28 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -15,7 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#include <functional> +#include <memory> +#include <set> #include <string> +#include <vector> #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // Used to complete/verify CollInstance. struct InstanceRec; + typedef std::function<void(InstanceRec*)> IRConsumer; struct InstanceRec { // This structure has two mutexes so that a possibly long @@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec) LOCKS_EXCLUDED(irec->out_mu); - friend class CollectiveParamResolverLocalTest; - // Establishes the requested number of subdivision permutations based on the - // ring order implicit in the device order. - static void GenerateSubdivPerms(const string& device, int source_rank, - CollectiveParams* cp); - // Establishes the subdivisions for broadcast op. The first subdiv executes - // binary tree bcast with one device per task. Each subsequent subdiv - // executes intra-task binary tree broadcast. - static void GenerateBcastSubdivPerms(const string& device, int source_rank, - const std::vector<int>& dev_per_task, - CollectiveParams* cp); - const DeviceMgr* dev_mgr_; DeviceResolverInterface* dev_resolver_; // Not owned. string task_name_; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index 9ea23b72d2..9e1e2e8d5b 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -44,31 +44,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { task_name)); } - void GenSubdivPerms(const string& device, int source_rank, - CollectiveParams* cp) { - CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp); - } - - // Calls GenerateBcastSubdivPerms for device at `device_rank`. Checks if the - // generated subdiv perms, ranks, and source ranks match the expected values. 
- void BcastSubdivPerms( - CollectiveParams* cp, const std::vector<int>& dev_per_task, - int device_rank, int source_rank, - const std::vector<std::vector<int>>& expected_subdiv_perms, - const std::vector<int>& expected_subdiv_rank, - const std::vector<int>& expected_subdiv_source_rank) { - cp->subdiv_rank.clear(); - cp->instance.impl_details.subdiv_permutations.clear(); - cp->instance.impl_details.subdiv_source_rank.clear(); - CollectiveParamResolverLocal::GenerateBcastSubdivPerms( - cp->instance.device_names[device_rank], source_rank, dev_per_task, cp); - EXPECT_EQ(expected_subdiv_perms, - cp->instance.impl_details.subdiv_permutations); - EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank); - EXPECT_EQ(expected_subdiv_source_rank, - cp->instance.impl_details.subdiv_source_rank); - } - std::vector<Device*> devices_; std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceResolverLocal> drl_; @@ -114,7 +89,6 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { cps[i].instance.device_names[j]); EXPECT_TRUE(cps[i].task.is_local[j]); } - EXPECT_EQ(cps[i].subdiv_rank[0], i); EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0); EXPECT_FALSE(cps[i].is_source); EXPECT_EQ(cps[i].default_rank, i); @@ -161,188 +135,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { cps[i].instance.device_names[j]); EXPECT_TRUE(cps[i].task.is_local[j]); } - ASSERT_GT(cps[i].subdiv_rank.size(), 0); - EXPECT_EQ(cps[i].subdiv_rank[0], i); - ASSERT_GT(cps[i].instance.impl_details.subdiv_source_rank.size(), 0); - EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank[0], 1); EXPECT_EQ(cps[i].is_source, (i == 1)); EXPECT_EQ(cps[i].default_rank, i); EXPECT_TRUE(cps[i].instance.same_num_devices_per_task); } } -TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) { - static const int kNumDevsPerTask = 8; - static const int kNumTasks = 3; - static const int kNumDevs = kNumDevsPerTask * kNumTasks; - CollectiveParams cp; - std::vector<string> device_names; - std::vector<string> task_names; - cp.group.group_key = 1; - cp.group.group_size = kNumDevs; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = kNumTasks; - cp.instance.instance_key = 3; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({5}); - cp.instance.impl_details.subdiv_offsets.push_back(0); - cp.is_source = false; - for (int i = 0; i < kNumDevs; ++i) { - int task_id = i / kNumDevsPerTask; - int dev_id = i % kNumDevsPerTask; - string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); - task_names.push_back(task_name); - string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - device_names.push_back(device_name); - cp.instance.task_names.push_back(task_name); - cp.instance.device_names.push_back(device_name); - } - - int test_rank = 0; - cp.default_rank = test_rank; - cp.instance.impl_details.subdiv_offsets = {0, 4}; - GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp); - std::vector<int> expected_0 = {0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23}; - std::vector<int> expected_1 = {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, - 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}; - for (int i = 0; i < kNumDevs; ++i) { - EXPECT_EQ(expected_0[i], - cp.instance.impl_details.subdiv_permutations[0][i]); - EXPECT_EQ(expected_1[i], - cp.instance.impl_details.subdiv_permutations[1][i]); - } - EXPECT_EQ(0, cp.subdiv_rank[0]); - 
EXPECT_EQ(4, cp.subdiv_rank[1]); - - test_rank = 3; - cp.default_rank = test_rank; - cp.instance.impl_details.subdiv_offsets = {3, -3}; - cp.instance.impl_details.subdiv_permutations.clear(); - GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp); - expected_0 = {3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, - 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18}; - expected_1 = {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9, - 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}; - for (int i = 0; i < kNumDevs; ++i) { - EXPECT_EQ(expected_0[i], - cp.instance.impl_details.subdiv_permutations[0][i]); - EXPECT_EQ(expected_1[i], - cp.instance.impl_details.subdiv_permutations[1][i]); - } - EXPECT_EQ(0, cp.subdiv_rank[0]); - EXPECT_EQ(1, cp.subdiv_rank[1]); -} - -TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) { - CollectiveParams cp; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = 1; - cp.instance.type = BROADCAST_COLLECTIVE; - for (int i = 0; i < 8; i++) { - string dev_name = - strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i); - cp.instance.device_names.push_back(dev_name); - } - std::vector<int> dev_per_task = {8}; - - // source 0 device 0 - BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, - {0}); - - // source 2 device 2 - BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, - {2}); - - // source 2 device 0 - BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, - {2}); -} - -TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) { - CollectiveParams cp; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = 4; - cp.instance.type = BROADCAST_COLLECTIVE; - for (int ti = 0; ti < cp.group.num_tasks; ti++) { - for (int di = 0; di < 8; di++) { - string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti, - "/device:GPU:", di); - cp.instance.device_names.push_back(dev_name); - } - } - std::vector<int> dev_per_task = {8, 8, 8, 8}; - - // source 0 device 0 - BcastSubdivPerms(&cp, dev_per_task, 0, 0, - {{0, 8, 16, 24}, - {0, 1, 2, 3, 4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23}, - {24, 25, 26, 27, 28, 29, 30, 31}}, - {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); - - // source 2 device 0 - BcastSubdivPerms(&cp, dev_per_task, 0, 2, - {{2, 8, 16, 24}, - {0, 1, 2, 3, 4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23}, - {24, 25, 26, 27, 28, 29, 30, 31}}, - {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); - - // source 9 device 9 - BcastSubdivPerms(&cp, dev_per_task, 9, 9, - {{0, 9, 16, 24}, - {0, 1, 2, 3, 4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13, 14, 15}, - {16, 17, 18, 19, 20, 21, 22, 23}, - {24, 25, 26, 27, 28, 29, 30, 31}}, - {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0}); -} - -TEST_F(CollectiveParamResolverLocalTest, - GenerateBcastSubdivPerms4TasksVariableGPU) { - CollectiveParams cp; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = 4; - std::vector<int> dev_per_task = {4, 4, 6, 8}; - for (int ti = 0; ti < cp.group.num_tasks; ti++) { - for (int di = 0; di < dev_per_task[ti]; di++) { - string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti, - "/device:GPU:", di); - cp.instance.device_names.push_back(dev_name); - } - } - - // source 0 device 0 - BcastSubdivPerms(&cp, dev_per_task, 0, 0, - {{0, 4, 8, 14}, - {0, 1, 2, 3}, - {4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13}, - {14, 15, 16, 17, 18, 19, 20, 21}}, - {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); - - // source 2 device 0 - BcastSubdivPerms(&cp, 
dev_per_task, 0, 2, - {{2, 4, 8, 14}, - {0, 1, 2, 3}, - {4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13}, - {14, 15, 16, 17, 18, 19, 20, 21}}, - {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); - - // source 9 device 5 - BcastSubdivPerms(&cp, dev_per_task, 5, 9, - {{0, 4, 9, 14}, - {0, 1, 2, 3}, - {4, 5, 6, 7}, - {8, 9, 10, 11, 12, 13}, - {14, 15, 16, 17, 18, 19, 20, 21}}, - {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0}); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc new file mode 100644 index 0000000000..195521a078 --- /dev/null +++ b/tensorflow/core/common_runtime/collective_util.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/collective_util.h" + +#include <memory> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace collective_util { + +/*static*/ +Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, + const string& device_name, Device** device, + DeviceLocality* device_locality) { + if (!dev_mgr) { + return errors::Internal("Required non-null dev_mgr ", dev_mgr, + " for InitializeDeviceAndLocality"); + } + + Status status = dev_mgr->LookupDevice(device_name, device); + if (status.ok()) { + CHECK(*device); + *device_locality = (*device)->attributes().locality(); + } else { + LOG(ERROR) << "Failed to find device " << device_name; + for (auto d : dev_mgr->ListDevices()) { + LOG(ERROR) << "Available devices " << d->name(); + } + } + return status; +} + +/*static*/ +string SubdivPermDebugString(const CollectiveParams& col_params) { + const auto& subdiv_perms = + col_params.instance.impl_details.subdiv_permutations; + string buf; + for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) { + strings::StrAppend(&buf, "Subdiv ", sdi, " device order:\n"); + for (int di = 0; di < subdiv_perms[sdi].size(); ++di) { + int idx = subdiv_perms[sdi][di]; + if (idx >= 0) { + CHECK_GT(col_params.instance.device_names.size(), idx); + strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n"); + } + } + strings::StrAppend(&buf, " subdiv_offsets: "); + for (auto o : col_params.instance.impl_details.subdiv_offsets) + strings::StrAppend(&buf, o, " "); + strings::StrAppend(&buf, " SubdivRank: "); + for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " "); + if (col_params.instance.type == BROADCAST_COLLECTIVE) { + strings::StrAppend(&buf, " subdiv_source_rank: "); + for (auto src : col_params.instance.impl_details.subdiv_source_rank) + strings::StrAppend(&buf, src, " "); + 
} + strings::StrAppend(&buf, "\n"); + } + return buf; +} + +} // namespace collective_util +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h new file mode 100644 index 0000000000..ebb5731bec --- /dev/null +++ b/tensorflow/core/common_runtime/collective_util.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ + +#include <string> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace collective_util { + +Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, + const string& device_name, Device** device, + DeviceLocality* device_locality); +string SubdivPermDebugString(const CollectiveParams& col_params); + +} // namespace collective_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc new file mode 100644 index 0000000000..eae34997d9 --- /dev/null +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -0,0 +1,440 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
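
SubdivPermDebugString in the new collective_util renders each subdivision permutation as a device-name list, skipping the -1 entries that mark non-participating slots. A stripped-down analogue, useful for checking a permutation by eye:

    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified analogue of SubdivPermDebugString: map each permutation
    // entry back to a device name, skipping -1 slots.
    std::string DebugString(const std::vector<std::vector<int>>& perms,
                            const std::vector<std::string>& devices) {
      std::string buf;
      for (size_t sdi = 0; sdi < perms.size(); ++sdi) {
        buf += "Subdiv " + std::to_string(sdi) + " device order:\n";
        for (int idx : perms[sdi])
          if (idx >= 0) buf += "  " + devices[idx] + "\n";
      }
      return buf;
    }

    int main() {
      std::vector<std::string> devices = {"GPU:0", "GPU:1", "GPU:2", "GPU:3"};
      std::cout << DebugString({{0, 2}, {0, 1}, {2, 3}}, devices);
    }
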
+==============================================================================*/ +#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h" + +#include <functional> +#include <memory> +#include <string> +#include <utility> + +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/collective_util.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +// Set true for greater intelligibility of debug mode log messages. +#define READABLE_KEYS false + +namespace tensorflow { + +namespace { +// Key to be used for BufRendezvous by Broadcaster. +string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank, + int dst_rank) { + if (READABLE_KEYS) { + return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv, + "):src(", src_rank, "):dst(", dst_rank, ")"); + } else { + // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash. + return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank); + } +} +} // namespace + +HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster() + : col_ctx_(nullptr), + col_params_(nullptr), + done_(nullptr), + is_source_(false) {} + +int HierarchicalTreeBroadcaster::GetDeviceTask( + int device_rank, const std::vector<int>& dev_per_task) { + int num_tasks = static_cast<int>(dev_per_task.size()); + int task_lo = 0; + int task_hi; + for (int ti = 0; ti < num_tasks; ti++) { + task_hi = task_lo + dev_per_task[ti]; + if (task_lo <= device_rank && device_rank < task_hi) return ti; + task_lo = task_hi; + } + LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi + << " devices"; + return -1; +} + +Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( + CollectiveParams* col_params) { + CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE); + CHECK_EQ(col_params->instance.impl_details.collective_name, + "HierarchicalTreeBroadcast"); + const string& device_name = + col_params->instance.device_names[col_params->default_rank]; + // Start by counting the devices in each task. + // Precondition: device_names must be sorted so that all devices in + // the same task are adjacent. 
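
The device counting that follows this precondition comment is the same loop that used to live in GenerateSubdivPerms: a run-length pass over the sorted task names, producing one device count per task. Isolated, it looks like the sketch below; DevPerTask is a hypothetical free-function form, not a name in the patch.

    #include <iostream>
    #include <string>
    #include <vector>

    // Run-length count over sorted task names, one entry per task.
    // Precondition (as in InitializeCollectiveParams): all devices of a
    // task are adjacent in the list.
    std::vector<int> DevPerTask(const std::vector<std::string>& task_names) {
      std::vector<int> dev_per_task;
      const std::string* prior = &task_names[0];
      int count = 1;
      for (size_t di = 1; di < task_names.size(); ++di) {
        if (task_names[di] != *prior) {
          dev_per_task.push_back(count);
          count = 1;
          prior = &task_names[di];
        } else {
          ++count;
        }
      }
      dev_per_task.push_back(count);
      return dev_per_task;
    }

    int main() {
      std::vector<std::string> tasks = {"/task:0", "/task:0", "/task:1",
                                        "/task:1", "/task:1", "/task:2"};
      for (int n : DevPerTask(tasks)) std::cout << n << ' ';  // 2 3 1
      std::cout << '\n';
    }
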
+ VLOG(2) << "Sorted task names: " + << str_util::Join(col_params->instance.task_names, ", "); + std::vector<int> dev_per_task; + const string* prior_task_name = &col_params->instance.task_names[0]; + int dev_count = 1; + for (int di = 1; di < col_params->group.group_size; ++di) { + if (col_params->instance.task_names[di] != *prior_task_name) { + dev_per_task.push_back(dev_count); + dev_count = 1; + prior_task_name = &col_params->instance.task_names[di]; + } else { + ++dev_count; + } + } + dev_per_task.push_back(dev_count); + CHECK_EQ(col_params->group.num_tasks, dev_per_task.size()); + + if (VLOG_IS_ON(2)) { + string dpt_buf; + for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";"); + VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device=" + << device_name << " source_rank=" << col_params->source_rank + << " dev_per_task=" << dpt_buf; + } + int num_tasks = col_params->group.num_tasks; + // If there is just 1 task, then execute binary tree broadcast over all + // devices. Otherwise, the first subdiv is inter-task broadcast, and then + // there are N more subdivs, where N is #task. + int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0); + int total_num_devices = 0; + for (int num_dev : dev_per_task) total_num_devices += num_dev; + + col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params->subdiv_rank.reserve(num_subdivs); + col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs); + + // Inter-task subdiv. Pick one device from each task - this is the source + // device if it belongs to that task, or device 0 for that task. If a device + // does not participate in the subdiv, set subdiv_rank to -1. + if (num_tasks > 1) { + const int sdi = 0; + std::vector<int>& perm = + col_params->instance.impl_details.subdiv_permutations[sdi]; + CHECK_EQ(perm.size(), 0); + int device_count = 0; + int source_task = GetDeviceTask(col_params->source_rank, dev_per_task); + for (int ti = 0; ti < col_params->group.num_tasks; ti++) { + bool participate = false; + if (source_task == ti) { + // Source device belongs to this task. + perm.push_back(col_params->source_rank); + participate = + col_params->instance.device_names[col_params->source_rank] == + device_name; + } else { + // Source does not belong to this task, choose dev 0. + perm.push_back(device_count); + participate = + col_params->instance.device_names[device_count] == device_name; + } + if (participate) col_params->subdiv_rank.push_back(ti); + device_count += dev_per_task[ti]; + } + if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1); + col_params->instance.impl_details.subdiv_source_rank.push_back(source_task); + } + + // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set + // source to dev 0 for that task if it does not contain original source, else + // set to rank of original source. If a device does not participate in + // the subdiv, set subdiv_rank to -1; + int abs_di = 0; + for (int ti = 0; ti < col_params->group.num_tasks; ti++) { + const int sdi = ti + (num_tasks > 1 ? 
1 : 0); + std::vector<int>& perm = + col_params->instance.impl_details.subdiv_permutations[sdi]; + CHECK_EQ(perm.size(), 0); + bool participate = false; + int subdiv_source = 0; + for (int di = 0; di < dev_per_task[ti]; di++) { + perm.push_back(abs_di); + if (col_params->instance.device_names[abs_di] == device_name) { + participate = true; + col_params->subdiv_rank.push_back(di); + } + if (abs_di == col_params->source_rank) subdiv_source = di; + abs_di++; + } + if (!participate) col_params->subdiv_rank.push_back(-1); + col_params->instance.impl_details.subdiv_source_rank.push_back( + subdiv_source); + } + + for (int sri = 0; sri < num_subdivs; sri++) { + CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0); + } + + VLOG(2) << collective_util::SubdivPermDebugString(*col_params); + return Status::OK(); +} + +Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( + CollectiveContext* col_ctx) { + CHECK(col_ctx->dev_mgr); + col_ctx_ = col_ctx; + col_params_ = &col_ctx->col_params; + return collective_util::InitializeDeviceAndLocality( + col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, + &col_ctx->device_locality); +} + +void HierarchicalTreeBroadcaster::Run(StatusCallback done) { + CHECK(col_ctx_); + CHECK(col_params_); + done_ = std::move(done); + is_source_ = col_params_->is_source; + RunTree(); +} + +// Binary tree parent/child relations are trivial to calculate, i.e. +// device at rank r is the parent of 2r+1 and 2r+2. The one exception +// is if the source is not rank 0. We treat that case as though the +// source is appended to the front of the rank ordering as well as +// continuing to occupy its current position. Hence we calculate as +// though each device's rank is actually r+1, then subtract 1 again to +// get the descendent ranks. If the source is not rank 0 then its +// descendants include both {0,1} and the descendents of its current +// position. Where a non-0-rank source is a descendent of another +// device, no send to it is necessary. + +/* static*/ +int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp, + int subdiv) { + DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); + int my_rank = cp.subdiv_rank[subdiv]; + if (-1 == my_rank) return -1; + + const auto& impl = cp.instance.impl_details; + DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); + int source_rank = impl.subdiv_source_rank[subdiv]; + if (my_rank == source_rank) return -1; + if (source_rank == 0) { + return (my_rank - 1) / 2; + } else { + int predecessor_rank = (my_rank / 2) - 1; + return (predecessor_rank < 0) ? 
source_rank : predecessor_rank; + } +} + +/* static */ +void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp, + int subdiv, + std::vector<int>* targets) { + DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size())); + int my_rank = cp.subdiv_rank[subdiv]; + if (-1 == my_rank) return; + + const auto& impl = cp.instance.impl_details; + DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size())); + int source_rank = impl.subdiv_source_rank[subdiv]; + + int group_size = 0; + for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) { + if (impl.subdiv_permutations[subdiv][i] >= 0) { + group_size++; + } + } + + targets->clear(); + int successor_rank = 0; + if (source_rank == 0) { + successor_rank = (2 * my_rank) + 1; + } else { + successor_rank = (2 * (my_rank + 1)); + } + DCHECK_NE(successor_rank, my_rank); + if (cp.is_source && source_rank != 0) { + // The source sends to rank 0,1 in addition to its positional + // descendants. + if (group_size > 1) { + targets->push_back(0); + } + if (group_size > 2 && source_rank != 1) { + targets->push_back(1); + } + } + for (int i = 0; i < 2; ++i) { + if (successor_rank < group_size && successor_rank != source_rank) { + targets->push_back(successor_rank); + } + ++successor_rank; + } +} + +// Executes a hierarchical tree broadcast. +// Each subdiv is a broadcast between a subset of the devices. +// If there is only one task, there is one subdiv comprising a broadcast between +// all devices belonging to the task. +// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global) +// subdiv, one device from each task participates in a binary tree broadcast. +// Each task receives a copy of the tensor on one device via this broadcast. +// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1 +// corresponds to broadcast between all devices on task i. Thus, each task +// participates in at most 2 subdivs. +void HierarchicalTreeBroadcaster::RunTree() { + int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); + // TODO(b/78352018): this is easily improved when a node participates in both + // first and second subdivision. It would first send to its descendents in + // the first subdiv, then wait until all pending ops are finished before + // sending to descendents in second subdiv. A better implementation would + // collapse the two send blocks. + for (int si = 0; si < num_subdivs; si++) { + int my_rank = col_params_->subdiv_rank[si]; + // If rank is -1, this device does not participate in this subdiv. + if (-1 == my_rank) continue; + int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si]; + if (VLOG_IS_ON(1)) { + string subdiv_buf; + for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) { + strings::StrAppend(&subdiv_buf, r, ","); + } + VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name + << " subdiv=" << si << " perm=" << subdiv_buf + << " my_rank=" << my_rank << " source_rank=" << source_rank; + } + + mutex mu; // also guards status_ while callbacks are pending + int pending_count = 0; // GUARDED_BY(mu) + condition_variable all_done; + + if (my_rank >= 0 && my_rank != source_rank) { + // Begin by receiving the value. 
+      int recv_from_rank = TreeRecvFrom(*col_params_, si);
+      Notification note;
+      DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
+                   [this, &mu, &note](const Status& s) {
+                     mutex_lock l(mu);
+                     status_.Update(s);
+                     note.Notify();
+                   });
+      note.WaitForNotification();
+    }
+
+    // Then forward value to all descendant devices.
+    if (my_rank >= 0 && status_.ok()) {
+      std::vector<int> send_to_ranks;
+      TreeSendTo(*col_params_, si, &send_to_ranks);
+      for (int i = 0; i < send_to_ranks.size(); ++i) {
+        int target_rank = send_to_ranks[i];
+        {
+          mutex_lock l(mu);
+          ++pending_count;
+        }
+        DispatchSend(si, target_rank, my_rank,
+                     (is_source_ ? col_ctx_->input : col_ctx_->output),
+                     [this, &mu, &pending_count, &all_done](const Status& s) {
+                       mutex_lock l(mu);
+                       status_.Update(s);
+                       --pending_count;
+                       if (pending_count == 0) {
+                         all_done.notify_all();
+                       }
+                     });
+      }
+    }
+
+    // For the original source device, we copy input to output if they are
+    // different.
+    // If there is only 1 subdiv, we do this in that subdiv. If there is more
+    // than 1 subdiv, then the original source device will participate in 2
+    // subdivs - the global inter-task broadcast and one local intra-task
+    // broadcast. In this case, we perform the copy in the second subdiv for
+    // this device.
+    if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+      VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
+              << " subdiv=" << si;
+      if (col_ctx_->input != col_ctx_->output &&
+          (DMAHelper::base(col_ctx_->input) !=
+           DMAHelper::base(col_ctx_->output))) {
+        {
+          mutex_lock l(mu);
+          ++pending_count;
+        }
+        DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+        CollectiveRemoteAccessLocal::MemCpyAsync(
+            op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
+            col_ctx_->op_ctx->input_alloc_attr(0),
+            col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+            col_ctx_->output, 0 /*stream_index*/,
+            [this, &mu, &pending_count, &all_done](const Status& s) {
+              mutex_lock l(mu);
+              status_.Update(s);
+              --pending_count;
+              if (0 == pending_count) {
+                all_done.notify_all();
+              }
+            });
+      }
+    }
+
+    // Then wait for all pending actions to complete.
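+    // Each DispatchSend callback, and the optional input-to-output copy
+    // above, decrements pending_count under `mu` and signals `all_done`
+    // when it reaches zero.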
+ { + mutex_lock l(mu); + if (pending_count > 0) { + all_done.wait(l); + } + } + } + VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_; + done_(status_); +} + +void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, + int src_rank, + const Tensor* src_tensor, + const StatusCallback& done) { + string send_buf_key = + BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); + int dst_idx = + col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank]; + VLOG(3) << "DispatchSend " << send_buf_key << " from_device " + << col_ctx_->device_name << " to_device " + << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv + << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx; + col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx], + col_params_->instance.task_names[dst_idx], + send_buf_key, col_ctx_->device, + col_ctx_->op_ctx->op_device_context(), + col_ctx_->op_ctx->output_alloc_attr(0), + src_tensor, col_ctx_->device_locality, done); +} + +void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank, + int dst_rank, Tensor* dst_tensor, + const StatusCallback& done) { + string recv_buf_key = + BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank); + int src_idx = + col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank]; + VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device " + << col_params_->instance.device_names[src_idx] << " to_device " + << col_ctx_->device_name << " subdiv=" << subdiv + << " src_rank=" << src_rank << " src_idx=" << src_idx; + col_ctx_->col_exec->RecvFromPeer( + col_params_->instance.device_names[src_idx], + col_params_->instance.task_names[src_idx], + col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device, + col_ctx_->op_ctx->op_device_context(), + col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, + col_ctx_->device_locality, 0 /*stream_index*/, done); +} + +REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster); + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h index 799228b161..ceb9baad30 100644 --- a/tensorflow/core/common_runtime/broadcaster.h +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h @@ -12,25 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ #include <vector> + #include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/framework/collective.h" -#include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { -// Tree-algorithm implementation of collective broadcast. -class Broadcaster { +// Hierarchical tree-algorithm implementation of collective broadcast. 
+class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface { public: - Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, - OpKernelContext* ctx, OpKernelContext::Params* params, - const CollectiveParams& col_params, const string& exec_key, - int64 step_id, Tensor* output); + HierarchicalTreeBroadcaster(); + ~HierarchicalTreeBroadcaster() override = default; + + // Establishes the subdiv permutations needed for a hierarchical broadcast. + // If all devices are local, establishes a single subdiv comprising all + // devices. If any devices are on a different task, establishes n+1 subdivs + // for n tasks. + // The first subdiv comprises one device per task which gets the tensor on + // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task + // i. + Status InitializeCollectiveParams(CollectiveParams* col_params) override; - void Run(StatusCallback done); + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; + + // Begins async execution of the hierarchical tree broadcast. + // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; // Returns the rank of the device from which this device should receive // its value, -1 if no value should be received. @@ -42,32 +57,29 @@ class Broadcaster { std::vector<int>* targets); private: + // Get the task to which the device at `device_rank` belongs. + int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task); + // Sends `src_tensor` asynchronously from this device to device at `dst_rank` // in `subdiv`. Calls `done` upon completion. void DispatchSend(int subdiv, int dst_rank, int src_rank, const Tensor* src_tensor, const StatusCallback& done); + // Receives a tensor into the memory buffer owned by `dst_tensor` at this // device from device at `src_rank` in `subdiv`. Calls `done` upon // completion. void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor, const StatusCallback& done); + // Executes the hierarchical broadcast defined by this op. void RunTree(); - Status status_; - CollectiveExecutor* col_exec_; // Not owned - const DeviceMgr* dev_mgr_; // Not owned - OpKernelContext* ctx_; // Not owned - const CollectiveParams& col_params_; - const string exec_key_; - const int rank_; - const bool is_source_; - Tensor* output_; // Not owned - std::unique_ptr<CollectiveAdapter> ca_; + CollectiveContext* col_ctx_; // Not owned + const CollectiveParams* col_params_; // Not owned StatusCallback done_; - Device* device_; // The device for which this instance labors - DeviceLocality device_locality_; + Status status_; + bool is_source_; }; } // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index 3960fc6c97..da0e359cf8 100644 --- a/tensorflow/core/common_runtime/broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
 
 #include <algorithm>
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
@@ -41,7 +41,7 @@ static int64 kStepId = 123;
 
 // The test harness won't allow a mixture of fixture and non-fixture
 // tests in one file, so this is a trivial fixture for tests that don't
-// need the heavy-weight BroadcasterTest fixture.
+// need the heavy-weight HierarchicalTreeBroadcasterTest fixture.
 class TrivialTest : public ::testing::Test {
  protected:
  TrivialTest() {}
@@ -53,23 +53,23 @@ class TrivialTest : public ::testing::Test {
 // R = tested rank
 // RF = receive-from rank
 // ST = send_to rank vector
-#define DEF_TL_TEST(D, S, R, RF, ST)                               \
-  TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
-    CollectiveParams cp;                                           \
-    cp.group.group_size = D;                                       \
-    cp.instance.impl_details.subdiv_source_rank = {S};             \
-    cp.instance.impl_details.subdiv_permutations.push_back(        \
-        std::vector<int>(D, 0));                                   \
-    cp.subdiv_rank = {R};                                          \
-    cp.is_source = (S == R);                                       \
-    EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp, 0));               \
-    std::vector<int> expected = ST;                                \
-    std::vector<int> send_to;                                      \
-    Broadcaster::TreeSendTo(cp, 0, &send_to);                      \
-    ASSERT_EQ(expected.size(), send_to.size());                    \
-    for (int i = 0; i < expected.size(); ++i) {                    \
-      EXPECT_EQ(expected[i], send_to[i]);                          \
-    }                                                              \
+#define DEF_TL_TEST(D, S, R, RF, ST)                                 \
+  TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) {   \
+    CollectiveParams cp;                                             \
+    cp.group.group_size = D;                                         \
+    cp.instance.impl_details.subdiv_source_rank = {S};               \
+    cp.instance.impl_details.subdiv_permutations.push_back(          \
+        std::vector<int>(D, 0));                                     \
+    cp.subdiv_rank = {R};                                            \
+    cp.is_source = (S == R);                                         \
+    EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
+    std::vector<int> expected = ST;                                  \
+    std::vector<int> send_to;                                        \
+    HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to);        \
+    ASSERT_EQ(expected.size(), send_to.size());                      \
+    for (int i = 0; i < expected.size(); ++i) {                      \
+      EXPECT_EQ(expected[i], send_to[i]);                            \
+    }                                                                \
   }
 
 #define V(...) std::vector<int>({__VA_ARGS__})
@@ -130,7 +130,7 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
 
 // Wraps CollectiveRemoteAccessLocal with the ability to return an
 // error status to the N'th action.
-// TODO(tucker): factor out of this file and ring_reducer_test.cc
+// TODO(b/113171733): factor out of this file and ring_reducer_test.cc
 // into a single common source.
class FailTestRMA : public CollectiveRemoteAccessLocal { public: @@ -187,31 +187,32 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { int fail_after_ GUARDED_BY(mu_); }; -class BroadcasterTest : public ::testing::Test { +class HierarchicalTreeBroadcasterTest : public ::testing::Test { protected: - BroadcasterTest() : device_type_(DEVICE_CPU) {} + HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {} - ~BroadcasterTest() override { + ~HierarchicalTreeBroadcasterTest() override { stop_ = true; - for (auto i : instances_) { - delete i; - } + for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); } - void SetUp() override { -#if GOOGLE_CUDA +#ifdef GOOGLE_CUDA + void InitGPUDevices() { auto device_factory = DeviceFactory::GetFactory("GPU"); CHECK(device_factory); SessionOptions options; Status s = device_factory->CreateDevices( options, "/job:worker/replica:0/task:0", &gpu_devices_); CHECK(s.ok()); -#endif } +#endif void Init(int num_workers, int num_devices_per_worker, DataType dtype, const DeviceType& device_type, int fail_after) { +#ifdef GOOGLE_CUDA + InitGPUDevices(); +#endif VLOG(2) << "num_workers=" << num_workers << " num_devices_per_worker=" << num_devices_per_worker; int total_num_devices = num_workers * num_devices_per_worker; @@ -400,8 +401,6 @@ class BroadcasterTest : public ::testing::Test { return GetKernel(node_def, device_type, device); } - void BuildColParams() {} - template <typename T> void RunTest(DataType dtype, const DeviceType& device_type, int num_workers, int num_devices, int tensor_len, int fail_after, @@ -511,10 +510,47 @@ class BroadcasterTest : public ::testing::Test { } } + void RunSubdivPermsTest( + CollectiveParams* cp, + const std::vector<std::vector<int>>& expected_subdiv_perms, + const std::vector<int>& expected_subdiv_rank, + const std::vector<int>& expected_subdiv_source_rank) { + col_exec_ = nullptr; + cp->instance.impl_details.subdiv_permutations.clear(); + cp->subdiv_rank.clear(); + cp->instance.impl_details.subdiv_source_rank.clear(); + // Create a stub broadcaster only for testing param initialization. 
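+    // InitializeCollectiveParams only fills in fields of `cp` (it is
+    // effectively static; see the note in collective.h), so a
+    // default-constructed instance with no CollectiveContext suffices.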
+    HierarchicalTreeBroadcaster broadcaster;
+    TF_CHECK_OK(broadcaster.InitializeCollectiveParams(cp));
+    EXPECT_EQ(expected_subdiv_perms,
+              cp->instance.impl_details.subdiv_permutations);
+    EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+    EXPECT_EQ(expected_subdiv_source_rank,
+              cp->instance.impl_details.subdiv_source_rank);
+  }
+
+  void PrepColParamsForSubdivPermsTest(CollectiveParams* cp, int num_tasks,
+                                       int num_gpus) {
+    cp->group.device_type = DeviceType("GPU");
+    cp->group.num_tasks = num_tasks;
+    cp->group.group_size = num_tasks * num_gpus;
+    cp->instance.type = BROADCAST_COLLECTIVE;
+    cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+    for (int ti = 0; ti < num_tasks; ti++) {
+      string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+      for (int di = 0; di < num_gpus; di++) {
+        string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+        cp->instance.task_names.push_back(task_name);
+        cp->instance.device_names.push_back(dev_name);
+      }
+    }
+  }
+
   class DeviceInstance {
    public:
     DeviceInstance(int rank, const string& dev_name,
-                   const DeviceType& device_type, BroadcasterTest* parent)
+                   const DeviceType& device_type,
+                   HierarchicalTreeBroadcasterTest* parent)
         : parent_(parent),
           dev_name_(dev_name),
           device_type_(device_type),
@@ -636,21 +672,20 @@ class BroadcasterTest : public ::testing::Test {
             ctx.allocate_output(0, tensor_.shape(), &output_tensor_ptr));
       }
       CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+      const Tensor* input_tensor_ptr =
+          col_params_.is_source ? &tensor_ : nullptr;
 
       // Prepare a Broadcaster instance.
       string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
-      Broadcaster broadcaster(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
-                              &op_params, col_params_, exec_key, kStepId,
-                              output_tensor_ptr);
-
-      // Start execution in a threadpool then wait for completion.
-      Notification notification;
-      broadcaster.Run([this, &notification](Status s) {
-        status_ = s;
-        notification.Notify();
-      });
-      notification.WaitForNotification();
+      HierarchicalTreeBroadcaster broadcaster;
+      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+                                &ctx, &op_params, col_params_, exec_key,
+                                kStepId, input_tensor_ptr, output_tensor_ptr);
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+
+      // Run the broadcast.
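+      // HierarchicalTreeBroadcaster::Run() does not return until RunTree()
+      // has invoked the done callback, so no Notification is needed here.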
+ broadcaster.Run([this](Status s) { status_ = s; }); if (status_.ok()) { CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape())); } @@ -658,15 +693,13 @@ class BroadcasterTest : public ::testing::Test { dev_ctx->Unref(); } - BroadcasterTest* parent_; + HierarchicalTreeBroadcasterTest* parent_; string dev_name_; DeviceType device_type_ = DEVICE_CPU; int rank_; Tensor tensor_; Device* device_; CollectiveParams col_params_; - std::unique_ptr<CollectiveAdapter> ca_; - std::unique_ptr<OpKernelContext> ctx_; Status status_; }; // class DeviceInstance @@ -688,6 +721,118 @@ class BroadcasterTest : public ::testing::Test { int failure_count_ GUARDED_BY(mu_) = 0; }; +TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) { + CollectiveParams cp; + PrepColParamsForSubdivPermsTest(&cp, 1, 8); + + // source 0 device 0 + cp.source_rank = 0; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0}); + + // source 2 device 2 + cp.source_rank = 2; + cp.default_rank = 2; + RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2}); + + // source 2 device 0 + cp.source_rank = 2; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2}); +} + +TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { + CollectiveParams cp; + PrepColParamsForSubdivPermsTest(&cp, 4, 8); + + // source 0 device 0 + cp.source_rank = 0; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, + {{0, 8, 16, 24}, + {0, 1, 2, 3, 4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23}, + {24, 25, 26, 27, 28, 29, 30, 31}}, + {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); + + // source 2 device 0 + cp.source_rank = 2; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, + {{2, 8, 16, 24}, + {0, 1, 2, 3, 4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23}, + {24, 25, 26, 27, 28, 29, 30, 31}}, + {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); + + // source 9 device 9 + cp.source_rank = 9; + cp.default_rank = 9; + RunSubdivPermsTest(&cp, + {{0, 9, 16, 24}, + {0, 1, 2, 3, 4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23}, + {24, 25, 26, 27, 28, 29, 30, 31}}, + {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0}); +} + +TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { + CollectiveParams cp; + int num_tasks = 4; + cp.group.device_type = DeviceType("GPU"); + cp.group.num_tasks = num_tasks; + cp.group.group_size = 0; + cp.instance.type = BROADCAST_COLLECTIVE; + cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast"; + std::vector<int> dev_per_task = {4, 4, 6, 8}; + for (int ti = 0; ti < cp.group.num_tasks; ti++) { + string task_name = strings::StrCat("/job:worker/replica:0/task:", ti); + for (int di = 0; di < dev_per_task[ti]; di++) { + string dev_name = strings::StrCat(task_name, "/device:GPU:", di); + cp.instance.task_names.push_back(task_name); + cp.instance.device_names.push_back(dev_name); + cp.group.group_size++; + } + } + + // source 0 device 0 + cp.source_rank = 0; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, + {{0, 4, 8, 14}, + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13}, + {14, 15, 16, 17, 18, 19, 20, 21}}, + {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); + + // source 2 device 0 + cp.source_rank = 2; + cp.default_rank = 0; + RunSubdivPermsTest(&cp, + {{2, 4, 8, 14}, + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13}, + {14, 15, 16, 17, 18, 19, 20, 21}}, + {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); + + // source 9 device 5 + cp.source_rank = 9; + 
cp.default_rank = 5; + RunSubdivPermsTest(&cp, + {{0, 4, 9, 14}, + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11, 12, 13}, + {14, 15, 16, 17, 18, 19, 20, 21}}, + {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0}); +} + +// TODO(b/113171733): change to use TEST_P. // Tests of full broadcast algorithm, with different device and // data types. // B = data element type @@ -697,7 +842,7 @@ class BroadcasterTest : public ::testing::Test { // L = tensor length // A = abort after count #define DEF_TEST(B, T, W, D, L, A, F) \ - TEST_F(BroadcasterTest, \ + TEST_F(HierarchicalTreeBroadcasterTest, \ DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A##_Fw##F) { \ DataType dtype = DT_##B; \ switch (dtype) { \ diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc index e26761703b..bb8eeb141a 100644 --- a/tensorflow/core/common_runtime/ring_reducer.cc +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -14,13 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/ring_reducer.h" +#include <stdlib.h> +#include <atomic> +#include <functional> +#include <utility> + #include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/collective_util.h" #include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" // Set true for greater intelligibility of debug mode log messages. #define READABLE_KEYS false @@ -36,7 +52,8 @@ string RingReduceBufKey(const string& exec_key, int pass, int section, return strings::StrCat("rred(", exec_key, "):pass(", pass, "):section(", section, "):srcrank(", source_rank, ")"); } else { - // TODO(tucker): Try out some kind of denser encoding, e.g. 128 bit hash. + // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit + // hash. 
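+  // e.g. exec_key "k", pass 1, section 2, source rank 3 yields "k:1:2:3".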
return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
   }
 }
@@ -65,105 +82,149 @@ RingReducer::RingField* RingReducer::PCQueue::Dequeue() {
   return rf;
 }
 
-RingReducer::RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
-                         OpKernelContext* ctx,
-                         OpKernelContext::Params* op_params,
-                         const CollectiveParams& col_params,
-                         const string& exec_key, int64 step_id,
-                         const Tensor* input, Tensor* output)
-    : col_exec_(col_exec),
-      dev_mgr_(dev_mgr),
-      ctx_(ctx),
-      op_params_(op_params),
-      col_params_(col_params),
-      exec_key_(exec_key),
-      input_(input),
-      output_(output),
-      rank_(col_params.subdiv_rank[0]),
-      step_id_(step_id),
-      group_size_(col_params.group.group_size),
-      num_subdivs_(static_cast<int>(
-          col_params.instance.impl_details.subdiv_permutations.size())),
+RingReducer::RingReducer()
+    : col_ctx_(nullptr),
+      col_params_(nullptr),
       done_(nullptr),
-      device_(nullptr),
-      device_name_(
-          col_params_.instance.device_names[col_params_.default_rank]) {
-  CHECK_GT(group_size_, 0);
-  CHECK_GT(num_subdivs_, 0);
-}
+      group_size_(-1),
+      num_subdivs_(-1) {}
 
 RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
 
-string RingReducer::TensorDebugString(Tensor tensor) {
-  const DeviceBase::GpuDeviceInfo* gpu_device_info =
-      ctx_->device()->tensorflow_gpu_device_info();
-  if (gpu_device_info) {
-    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
-    Notification note;
-    gpu_device_info->default_context->CopyDeviceTensorToCPU(
-        &tensor, "" /*tensor_name*/, device_, &cpu_tensor,
-        [&note](const Status& s) {
-          CHECK(s.ok());
-          note.Notify();
-        });
-    note.WaitForNotification();
-    return cpu_tensor.SummarizeValue(64);
-  } else {
-    return tensor.SummarizeValue(64);
+Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
+  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
+  const string& device_name =
+      col_params->instance.device_names[col_params->default_rank];
+  // Each subdiv permutation is a ring formed by rotating each
+  // single-task subsequence of devices by an offset. This makes the most
+  // sense when each task has the same number of devices, but we can't
+  // depend on that being the case, so we compute something that works in
+  // any case.
+
+  // Start by counting the devices in each task.
+  // Precondition: device_names must be sorted so that all devices in
+  // the same task are adjacent.
+  VLOG(2) << "Sorted task names: "
+          << str_util::Join(col_params->instance.task_names, ", ");
+  std::vector<int> dev_per_task;
+  const string* prior_task_name = &col_params->instance.task_names[0];
+  int dev_count = 1;
+  for (int di = 1; di < col_params->group.group_size; ++di) {
+    if (col_params->instance.task_names[di] != *prior_task_name) {
+      dev_per_task.push_back(dev_count);
+      dev_count = 1;
+      prior_task_name = &col_params->instance.task_names[di];
+    } else {
+      ++dev_count;
+    }
+  }
+  dev_per_task.push_back(dev_count);
+  CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+  // Generate a ring permutation for each requested offset.
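+  // For example, with two tasks of four devices each and
+  // subdiv_offsets = {0, 2}, subdiv 0 gets the identity permutation
+  // {0,1,2,3,4,5,6,7} and subdiv 1 rotates each task's block by 2:
+  // {2,3,0,1,6,7,4,5}.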
+ if (col_params->instance.impl_details.subdiv_offsets.empty()) { + return errors::Internal( + "Subdiv offsets should be non-empty for ring reducer, size=", + col_params->instance.impl_details.subdiv_offsets.size()); + } + VLOG(2) << "Setting up perms for col_params " << col_params + << " subdiv_permutations " + << &col_params->instance.impl_details.subdiv_permutations; + col_params->instance.impl_details.subdiv_permutations.resize( + col_params->instance.impl_details.subdiv_offsets.size()); + col_params->subdiv_rank.resize( + col_params->instance.impl_details.subdiv_offsets.size(), -1); + for (int sdi = 0; + sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) { + std::vector<int>& perm = + col_params->instance.impl_details.subdiv_permutations[sdi]; + CHECK_EQ(perm.size(), 0); + int offset = col_params->instance.impl_details.subdiv_offsets[sdi]; + // A negative subdivision offset is interpreted as follows: + // 1. Reverse the local device ordering. + // 2. Begin the subdivision at abs(offset) in the reversed ordering. + bool reverse = false; + if (offset < 0) { + offset = abs(offset); + reverse = true; + } + int prior_dev_count = 0; // sum over prior worker device counts + for (int ti = 0; ti < col_params->group.num_tasks; ++ti) { + for (int di = 0; di < dev_per_task[ti]; ++di) { + int di_offset = (di + offset) % dev_per_task[ti]; + int offset_di = + reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset; + // Device index in global subdivision permutation. + int permuted_di = prior_dev_count + offset_di; + int rank = static_cast<int>(perm.size()); + perm.push_back(permuted_di); + if (col_params->instance.device_names[permuted_di] == device_name) { + CHECK_EQ(permuted_di, col_params->default_rank); + col_params->subdiv_rank[sdi] = rank; + } + } + prior_dev_count += dev_per_task[ti]; + } + CHECK_EQ(col_params->group.group_size, perm.size()); } + + VLOG(2) << collective_util::SubdivPermDebugString(*col_params); + return Status::OK(); +} + +Status RingReducer::InitializeCollectiveContext(CollectiveContext* col_ctx) { + CHECK(col_ctx->dev_mgr); + col_ctx_ = col_ctx; + col_params_ = &col_ctx->col_params; + return collective_util::InitializeDeviceAndLocality( + col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, + &col_ctx->device_locality); } void RingReducer::Run(StatusCallback done) { + CHECK(col_ctx_); + CHECK(col_params_); done_ = std::move(done); + group_size_ = col_params_->group.group_size; + num_subdivs_ = static_cast<int>( + col_params_->instance.impl_details.subdiv_permutations.size()); + CHECK_GT(num_subdivs_, 0); - // Get local execution device. 
if (VLOG_IS_ON(1)) {
     string buf;
-    for (int r = 0; r < col_params_.instance.device_names.size(); ++r) {
+    for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
       strings::StrAppend(&buf, "dev ", r, " : ",
-                         col_params_.instance.device_names[r], "\n");
+                         col_params_->instance.device_names[r], "\n");
     }
     for (int sd = 0;
-         sd < col_params_.instance.impl_details.subdiv_permutations.size();
+         sd < col_params_->instance.impl_details.subdiv_permutations.size();
          ++sd) {
       strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
-      for (auto x : col_params_.instance.impl_details.subdiv_permutations[sd]) {
+      for (auto x :
+           col_params_->instance.impl_details.subdiv_permutations[sd]) {
         strings::StrAppend(&buf, x, ", ");
       }
     }
-    VLOG(1) << "RingReducer::Run for device " << device_name_
-            << " default_rank " << col_params_.default_rank << "\n"
+    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
+            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
   }
 
-  CHECK(dev_mgr_);
-  Status status = dev_mgr_->LookupDevice(
-      col_params_.instance.device_names[col_params_.default_rank], &device_);
-  if (!status.ok()) {
-    LOG(ERROR) << "Failed to find device "
-               << col_params_.instance.device_names[col_params_.default_rank];
-    for (auto d : dev_mgr_->ListDevices()) {
-      LOG(ERROR) << "Available device " << d->name();
-    }
-    done_(status);
-    return;
-  }
-  CHECK(device_);
-  device_locality_ = device_->attributes().locality();
-
-  VLOG(1) << this << " default_rank " << col_params_.default_rank << " cp "
-          << &col_params_ << ": " << col_params_.ToString();
 
   // Start by copying input to output if they're not already the same, i.e. if
   // we're not computing in-place on the input tensor.
-  if ((input_ != output_) &&
-      (DMAHelper::base(input_) != DMAHelper::base(output_))) {
+  if ((col_ctx_->input != col_ctx_->output) &&
+      (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
     // We are running in a blockable thread and the callback can't block so
     // just wait here on the copy.
     Notification note;
+    Status status;
     CollectiveRemoteAccessLocal::MemCpyAsync(
-        ctx_->input_device_context(0), ctx_->op_device_context(), device_,
-        device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
-        output_, 0 /*dev_to_dev_stream_index*/,
+        col_ctx_->op_ctx->input_device_context(0),
+        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
+        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
+        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [this, &note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
@@ -177,24 +238,43 @@ void RingReducer::Run(StatusCallback done) {
   ContinueAfterInputCopy();
 }
 
+string RingReducer::TensorDebugString(const Tensor& tensor) {
+  const DeviceBase::GpuDeviceInfo* gpu_device_info =
+      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+  if (gpu_device_info) {
+    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
+    Notification note;
+    gpu_device_info->default_context->CopyDeviceTensorToCPU(
+        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
+        [&note](const Status& s) {
+          CHECK(s.ok());
+          note.Notify();
+        });
+    note.WaitForNotification();
+    return cpu_tensor.SummarizeValue(64);
+  } else {
+    return tensor.SummarizeValue(64);
+  }
+}
+
 // Note that this function is blocking and must not run in any thread
 // which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() { - AllocatorAttributes attr = ctx_->output_alloc_attr(0); - ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_, - device_->GetAllocator(attr))); + AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0); + ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_, + col_ctx_->device->GetAllocator(attr))); - if (col_params_.final_op) { + if (col_params_->final_op) { // Create an on-device scalar value from group_size_ that may be needed // later. // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar // can be provided to the kernel in host memory? Tensor group_size_val = ca_->Scalar(group_size_); - if (col_params_.group.device_type != "CPU") { - group_size_tensor_ = - ca_->Scalar(device_->GetAllocator(ctx_->input_alloc_attr(0))); - DeviceContext* op_dev_ctx = ctx_->op_device_context(); - op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, device_, + if (col_params_->group.device_type != "CPU") { + group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator( + col_ctx_->op_ctx->input_alloc_attr(0))); + DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context(); + op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device, &group_size_tensor_, [this](const Status& s) { if (!s.ok()) { @@ -231,14 +311,14 @@ void RingReducer::StartAbort(const Status& s) { // cancellation on all of the outstanding CollectiveRemoteAccess // actions. if (abort_started) { - col_exec_->StartAbort(s); + col_ctx_->col_exec->StartAbort(s); } } void RingReducer::Finish(bool ok) { if (ok) { // Recover the output from the adaptor. - ca_->ConsumeFinalValue(output_); + ca_->ConsumeFinalValue(col_ctx_->output); } Status s; { @@ -275,7 +355,7 @@ Status RingReducer::ComputeBinOp(Device* device, OpKernel* op, Tensor* output, // TODO(tucker): Is it possible to cache and reuse these objects? They're // mostly identical inside one device execution. std::unique_ptr<SubContext> sub_ctx( - new SubContext(ctx_, op_params_, op, output, input)); + new SubContext(col_ctx_->op_ctx, col_ctx_->op_params, op, output, input)); device->Compute(op, sub_ctx->sub_ctx_); return sub_ctx->sub_ctx_->status(); } @@ -295,18 +375,18 @@ void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, rf->chunk_idx = chunk_idx; rf->subdiv_idx = subdiv_idx; rf->sc_idx = field_idx; - rf->rank = col_params_.subdiv_rank[subdiv_idx]; + rf->rank = col_params_->subdiv_rank[subdiv_idx]; rf->second_pass = false; rf->action = RF_INIT; // Recv from the device with preceding rank within the subdivision. 
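   // e.g. with group_size_ = 4, rank 0 receives from rank 3 and sends to
   // rank 1.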
int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_; int send_to_rank = (rf->rank + 1) % group_size_; - rf->recv_dev_idx = col_params_.instance.impl_details + rf->recv_dev_idx = col_params_->instance.impl_details .subdiv_permutations[subdiv_idx][recv_from_rank]; - int send_dev_idx = col_params_.instance.impl_details + int send_dev_idx = col_params_->instance.impl_details .subdiv_permutations[subdiv_idx][send_to_rank]; - rf->recv_is_remote = !col_params_.task.is_local[rf->recv_dev_idx]; - rf->send_is_remote = !col_params_.task.is_local[send_dev_idx]; + rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx]; + rf->send_is_remote = !col_params_->task.is_local[send_dev_idx]; if (ca_->ChunkBytes(rf->sc_idx) > 0) { // In pass 0 we skip Recv when rank = chunk_idx rf->do_recv = (rf->chunk_idx != rf->rank); @@ -360,45 +440,47 @@ string RingReducer::RingField::DebugString() const { void RingReducer::DispatchSend(RingField* rf, const StatusCallback& done) { CHECK(rf->do_send); - string send_buf_key = - RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, rf->rank); - VLOG(3) << "DispatchSend rank=" << col_params_.default_rank << " send key " + string send_buf_key = RingReduceBufKey(col_ctx_->exec_key, rf->second_pass, + rf->sc_idx, rf->rank); + VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key " << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx " << rf->sc_idx; int send_to_rank = (rf->rank + 1) % group_size_; - int send_to_dev_idx = col_params_.instance.impl_details + int send_to_dev_idx = col_params_->instance.impl_details .subdiv_permutations[rf->subdiv_idx][send_to_rank]; - col_exec_->PostToPeer(col_params_.instance.device_names[send_to_dev_idx], - col_params_.instance.task_names[send_to_dev_idx], - send_buf_key, device_, ctx_->op_device_context(), - ctx_->output_alloc_attr(0), &rf->chunk, - device_locality_, done); + col_ctx_->col_exec->PostToPeer( + col_params_->instance.device_names[send_to_dev_idx], + col_params_->instance.task_names[send_to_dev_idx], send_buf_key, + col_ctx_->device, col_ctx_->op_ctx->op_device_context(), + col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk, + col_ctx_->device_locality, done); } void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) { CHECK(rf->do_recv); string recv_buf_key = - RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, + RingReduceBufKey(col_ctx_->exec_key, rf->second_pass, rf->sc_idx, (rf->rank + (group_size_ - 1)) % group_size_); - VLOG(3) << "DispatchRecv rank=" << col_params_.default_rank << " recv key " + VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key " << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into " - << ((col_params_.merge_op != nullptr) ? "tmp_chunk" : "chunk"); - Tensor* dst_tensor = (!rf->second_pass && (col_params_.merge_op != nullptr)) + << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk"); + Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr)) ? 
&rf->tmp_chunk : &rf->chunk; - col_exec_->RecvFromPeer(col_params_.instance.device_names[rf->recv_dev_idx], - col_params_.instance.task_names[rf->recv_dev_idx], - col_params_.task.is_local[rf->recv_dev_idx], - recv_buf_key, device_, ctx_->op_device_context(), - ctx_->output_alloc_attr(0), dst_tensor, - device_locality_, rf->subdiv_idx, done); + col_ctx_->col_exec->RecvFromPeer( + col_params_->instance.device_names[rf->recv_dev_idx], + col_params_->instance.task_names[rf->recv_dev_idx], + col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key, + col_ctx_->device, col_ctx_->op_ctx->op_device_context(), + col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, + col_ctx_->device_locality, rf->subdiv_idx, done); } string RingReducer::FieldState() { - string s = strings::StrCat("RingReducer ", - strings::Hex(reinterpret_cast<uint64>(this)), - " exec ", exec_key_, " step_id=", step_id_, - " state of all ", rfv_.size(), " fields:"); + string s = strings::StrCat( + "RingReducer ", strings::Hex(reinterpret_cast<uint64>(this)), " exec ", + col_ctx_->exec_key, " step_id=", col_ctx_->step_id, " state of all ", + rfv_.size(), " fields:"); for (int i = 0; i < rfv_.size(); ++i) { s.append("\n"); s.append(rfv_[i].DebugString()); @@ -468,8 +550,9 @@ bool RingReducer::RunAsyncParts() { --recv_pending_count; if (!rf->second_pass) { rf->action = RF_REDUCE; - Status s = ComputeBinOp(device_, col_params_.merge_op.get(), - &rf->chunk, &rf->tmp_chunk); + Status s = + ComputeBinOp(col_ctx_->device, col_params_->merge_op.get(), + &rf->chunk, &rf->tmp_chunk); if (!s.ok()) { aborted = true; StartAbort(s); @@ -479,11 +562,12 @@ bool RingReducer::RunAsyncParts() { } break; case RF_REDUCE: - if (!rf->second_pass && col_params_.final_op.get() && rf->is_final) { + if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) { rf->action = RF_FINALIZE; group_size_tensor_ready_.WaitForNotification(); - Status s = ComputeBinOp(device_, col_params_.final_op.get(), - &rf->chunk, &group_size_tensor_); + Status s = + ComputeBinOp(col_ctx_->device, col_params_->final_op.get(), + &rf->chunk, &group_size_tensor_); if (!s.ok()) { aborted = true; StartAbort(s); @@ -552,9 +636,11 @@ bool RingReducer::RunAsyncParts() { CHECK_EQ(send_pending_count, 0); CHECK_EQ(recv_pending_count, 0); - VLOG(2) << this << " rank=" << rank_ << " finish;" + VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;" << " final value " << TensorDebugString(ca_->Value()); return !aborted; } +REGISTER_COLLECTIVE(RingReduce, RingReducer); + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h index 3e1988e787..0848e37b52 100644 --- a/tensorflow/core/common_runtime/ring_reducer.h +++ b/tensorflow/core/common_runtime/ring_reducer.h @@ -16,25 +16,35 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ #include <deque> +#include <memory> +#include <string> +#include <vector> #include "tensorflow/core/common_runtime/base_collective_executor.h" #include "tensorflow/core/framework/collective.h" -#include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { -class DeviceMgr; +class Device; // Ring-algorithm implementation of collective all-reduce. 
-class RingReducer { +class RingReducer : public CollectiveImplementationInterface { public: - RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, - OpKernelContext* ctx, OpKernelContext::Params* op_params, - const CollectiveParams& col_params, const string& exec_key, - int64 step_id, const Tensor* input, Tensor* output); + RingReducer(); + ~RingReducer() override; - virtual ~RingReducer(); + // Establishes the requested number of subdivision permutations based on the + // ring order implicit in the device order. + Status InitializeCollectiveParams(CollectiveParams* col_params) override; - void Run(StatusCallback done); + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; + + // Begins async execution of the ring reduce algorithm. + // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; private: // Called when a bad status is received that implies we should terminate @@ -101,7 +111,7 @@ class RingReducer { // For constructing log messages for debugging. string FieldState(); - string TensorDebugString(Tensor tensor); + string TensorDebugString(const Tensor& tensor); // Producer/Consumer Queue of RingField structs. class PCQueue { @@ -116,30 +126,19 @@ class RingReducer { std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_); }; - CollectiveExecutor* col_exec_; // Not owned - const DeviceMgr* dev_mgr_; // Not owned - OpKernelContext* ctx_; // Not owned - OpKernelContext::Params* op_params_; // Not owned - const CollectiveParams& col_params_; - const string exec_key_; - const Tensor* input_; // Not owned - Tensor* output_; // Not owned - const int rank_; - const int64 step_id_; - const int group_size_; - const int num_subdivs_; + CollectiveContext* col_ctx_; // Not owned + const CollectiveParams* col_params_; // Not owned + StatusCallback done_; + int group_size_; + int num_subdivs_; Tensor group_size_tensor_; Notification group_size_tensor_ready_; std::unique_ptr<CollectiveAdapter> ca_; - StatusCallback done_; - Device* device_; // The device for which this instance labors - const string device_name_; - DeviceLocality device_locality_; - mutex status_mu_; Status status_ GUARDED_BY(status_mu_); - std::vector<RingField> rfv_; + + friend class RingReducerTest; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index fcdf9deff8..5e079dbce6 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { -namespace { // Wraps CollectiveRemoteAccessLocal with the ability to return an // error status to the N'th action. 
@@ -135,27 +134,28 @@ class RingReducerTest : public ::testing::Test { protected: RingReducerTest() : device_type_(DEVICE_CPU) {} - void SetUp() override { -#if GOOGLE_CUDA +#ifdef GOOGLE_CUDA + void InitGPUDevices() { auto device_factory = DeviceFactory::GetFactory("GPU"); CHECK(device_factory); SessionOptions options; Status s = device_factory->CreateDevices( options, "/job:worker/replica:0/task:0", &gpu_devices_); CHECK(s.ok()); -#endif } +#endif ~RingReducerTest() override { stop_ = true; - for (auto i : instances_) { - delete i; - } + for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); } void Init(int num_workers, int num_devices, DataType dtype, const DeviceType& device_type, int num_subdivs, int fail_after) { +#ifdef GOOGLE_CUDA + InitGPUDevices(); +#endif device_type_ = device_type; std::vector<Device*> local_devices; SessionOptions sess_opts; @@ -201,6 +201,7 @@ class RingReducerTest : public ::testing::Test { col_params_.instance.instance_key = kInstanceKey; col_params_.instance.impl_details.subdiv_offsets.clear(); col_params_.instance.type = REDUCTION_COLLECTIVE; + col_params_.instance.impl_details.collective_name = "RingReduce"; col_params_.instance.data_type = dtype; col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); col_params_.subdiv_rank.resize(num_subdivs); @@ -373,6 +374,22 @@ class RingReducerTest : public ::testing::Test { return GetKernel(node_def, device_type, device); } + void RunSubdivPermsTest( + CollectiveParams* cp, + const std::vector<std::vector<int>>& expected_subdiv_perms, + const std::vector<int>& expected_subdiv_rank) { + col_exec_ = nullptr; + cp->instance.impl_details.subdiv_permutations.clear(); + cp->subdiv_rank.clear(); + // Create a stub ring reducer only for testing param initialization. + RingReducer reducer; + TF_CHECK_OK(reducer.InitializeCollectiveParams(cp)); + EXPECT_EQ(expected_subdiv_perms, + cp->instance.impl_details.subdiv_permutations); + EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank); + reducer.group_size_tensor_ready_.Notify(); // To unblock destructor. + } + class DeviceInstance { public: DeviceInstance(int rank, const string& dev_name, @@ -475,8 +492,8 @@ class RingReducerTest : public ::testing::Test { op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); - // We never actually execute the kernel, so we need to do the - // output allocation that it would do, ourselves. + // We never actually execute the kernel, so we need to do the output + // allocation it would do, ourselves. Tensor* output_tensor_ptr = nullptr; TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(), &output_tensor_ptr)); @@ -485,20 +502,17 @@ class RingReducerTest : public ::testing::Test { // Prepare a RingReducer instance. string exec_key = strings::StrCat(col_params_.instance.instance_key, ":0:0"); - RingReducer rr(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, - &op_params, col_params_, exec_key, kStepId, &tensor_, - &tensor_); - - // Start execution in a threadpool then wait for completion. 
-      Notification notification;
-      SchedClosure([this, &notification, &rr]() {
-        rr.Run([this, &notification](Status s) {
-          status_ = s;
-          notification.Notify();
-        });
-      });
-      notification.WaitForNotification();
-      CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+      RingReducer reducer;
+      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+                                &ctx, &op_params, col_params_, exec_key,
+                                kStepId, &tensor_, &tensor_);
+      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+
+      // Run the all-reduce.
+      reducer.Run([this](Status s) { status_ = s; });
+      if (status_.ok()) {
+        CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+      }
 
       dev_ctx->Unref();
     }
@@ -531,6 +545,57 @@ class RingReducerTest : public ::testing::Test {
   int32 reduce_counter_ GUARDED_BY(mu_) = 0;
 };
 
+TEST_F(RingReducerTest, InitializeParams) {
+  static const int kNumDevsPerTask = 8;
+  static const int kNumTasks = 3;
+  static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+  CollectiveParams cp;
+  std::vector<string> device_names;
+  std::vector<string> task_names;
+  cp.group.group_key = 1;
+  cp.group.group_size = kNumDevs;
+  cp.group.device_type = DeviceType("GPU");
+  cp.group.num_tasks = kNumTasks;
+  cp.instance.instance_key = 3;
+  cp.instance.type = REDUCTION_COLLECTIVE;
+  cp.instance.data_type = DataType(DT_FLOAT);
+  cp.instance.shape = TensorShape({5});
+  cp.instance.impl_details.collective_name = "RingReduce";
+  cp.instance.impl_details.subdiv_offsets.push_back(0);
+  cp.is_source = false;
+  for (int i = 0; i < kNumDevs; ++i) {
+    int task_id = i / kNumDevsPerTask;
+    int dev_id = i % kNumDevsPerTask;
+    string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
+    task_names.push_back(task_name);
+    string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
+    device_names.push_back(device_name);
+    cp.instance.task_names.push_back(task_name);
+    cp.instance.device_names.push_back(device_name);
+  }
+
+  int test_rank = 0;
+  cp.default_rank = test_rank;
+  cp.instance.impl_details.subdiv_offsets = {0, 4};
+  RunSubdivPermsTest(&cp,
+                     {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+                      {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
+                       8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
+                     {0, 4});
+
+  test_rank = 3;
+  cp.default_rank = test_rank;
+  cp.instance.impl_details.subdiv_offsets = {3, -3};
+  RunSubdivPermsTest(&cp,
+                     {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
+                       15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
+                      {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
+                       8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
+                     {0, 1});
+}
+
+// TODO(b/113171733): change to use TEST_P.
 #define DEF_TEST(B, T, W, D, S, L, A)                                         \
   TEST_F(RingReducerTest,                                                     \
          DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
@@ -604,5 +669,4 @@ DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
 DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
 #endif
 
-}  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index d4ac50cbbe..4cb277d5a8 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -21,6 +21,31 @@ limitations under the License.
 
 namespace tensorflow {
 
+namespace {
+// A RegistrationInfo object stores the registration details of a collective
+// implementation. `factory` is used to create instances of the collective
+// implementation.
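+// One RegistrationInfo is appended to the registry vector below for each
+// REGISTER_COLLECTIVE expansion, at static-initialization time.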
+struct RegistrationInfo { + // This constructor also creates, and stores in `param_resolver_instance`, + // what is effectively a static instance of the collective implementation. + // During param resolution of collective ops we return this static instance. + // The actual op execution gets a fresh instance using `factory`. + RegistrationInfo(const string& n, CollectiveRegistry::Factory f) + : name(n), + factory(std::move(f)), + param_resolver_instance(this->factory()) {} + string name; + CollectiveRegistry::Factory factory; + CollectiveImplementationInterface* param_resolver_instance; +}; + +std::vector<RegistrationInfo>* MutableCollectiveRegistry() { + static std::vector<RegistrationInfo>* registry = + new std::vector<RegistrationInfo>; + return registry; +} +} // namespace + string CollGroupParams::ToString() const { return strings::StrCat("CollGroupParams {group_key=", group_key, " group_size=", group_size, @@ -102,7 +127,8 @@ string CollectiveParams::ToString() const { strings::StrAppend(&v, " ", instance.ToString()); strings::StrAppend(&v, " ", task.ToString()); strings::StrAppend(&v, " default_rank=", default_rank, - " is_source=", is_source, " subdiv_rank={"); + " is_source=", is_source, " source_rank=", source_rank, + " subdiv_rank={"); for (const auto& r : subdiv_rank) { strings::StrAppend(&v, r, ","); } @@ -115,7 +141,81 @@ string CollectiveParams::ToString() const { return ctx->params_; } +CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec, + const DeviceMgr* dev_mgr, + OpKernelContext* ctx, + OpKernelContext::Params* op_params, + const CollectiveParams& col_params, + const string& exec_key, int64 step_id, + const Tensor* input, Tensor* output) + : col_exec(col_exec), + dev_mgr(dev_mgr), + op_ctx(ctx), + op_params(op_params), + col_params(col_params), + exec_key(exec_key), + step_id(step_id), + input(input), + output(output), + device(nullptr), + device_name(col_params.instance.device_names[col_params.default_rank]) {} + /*static*/ int64 CollectiveExecutor::kInvalidId = -1; +/*static*/ +Status CollectiveRegistry::Lookup( + const string& collective_name, + CollectiveImplementationInterface** implementation) { + return LookupHelper(collective_name, implementation, false); +} + +/*static*/ +Status CollectiveRegistry::LookupParamResolverInstance( + const string& collective_name, + CollectiveImplementationInterface** implementation) { + return LookupHelper(collective_name, implementation, true); +} + +/*static*/ +void CollectiveRegistry::GetAll( + std::vector<CollectiveImplementationInterface*>* implementations) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) + implementations->emplace_back(reg_info.factory()); +} + +/*static*/ +Status CollectiveRegistry::Register(const string& collective_name, + Factory factory) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) { + if (reg_info.name == collective_name) + return errors::Internal("Already registered collective ", + collective_name); + } + registry->emplace_back(collective_name, std::move(factory)); + return Status::OK(); +} + +/*static*/ +Status CollectiveRegistry::LookupHelper( + const string& collective_name, + CollectiveImplementationInterface** implementation, bool param_resolver) { + std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); + for (const RegistrationInfo& reg_info : *registry) { + if (reg_info.name == collective_name) { + if 
(param_resolver) { + *implementation = reg_info.param_resolver_instance; + } else { + *implementation = reg_info.factory(); + } + return Status::OK(); + } + } + return errors::Internal( + "CollectiveRegistry::Lookup did not find collective implementation ", + collective_name); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 0b37b3a88c..e35edb09d0 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/refcount.h" @@ -30,7 +31,8 @@ class CompleteGroupRequest; class CompleteGroupResponse; class CompleteInstanceRequest; class CompleteInstanceResponse; -class DeviceLocality; +class Device; +class DeviceMgr; class GetStepSequenceRequest; class GetStepSequenceResponse; class Op; @@ -64,10 +66,10 @@ struct CollGroupParams { // interpretation. On first execution the runtime will update this // structure with decisions that will guide all subsequent executions. struct CollImplDetails { + string collective_name; std::vector<std::vector<int>> subdiv_permutations; std::vector<int> subdiv_offsets; - // broadcast only: rank of source in each subdiv - std::vector<int> subdiv_source_rank; + std::vector<int> subdiv_source_rank; // rank of source in each subdiv }; // Data common to all members of a collective instance. @@ -104,6 +106,7 @@ struct CollectiveParams { string name = ""; // node name used only for log or error messages int default_rank = -1; // index of this op within device_names bool is_source = false; // broadcast only + int source_rank = -1; // broadcast only // Rank of this device in each subdivision permutation. std::vector<int> subdiv_rank; std::unique_ptr<OpKernel> merge_op; // reduction only @@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess { virtual void StartAbort(const Status& s) = 0; }; +class CollectiveContext { + public: + CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, + OpKernelContext* ctx, OpKernelContext::Params* op_params, + const CollectiveParams& col_params, const string& exec_key, + int64 step_id, const Tensor* input, Tensor* output); + + virtual ~CollectiveContext() = default; + + CollectiveExecutor* col_exec; // Not owned + const DeviceMgr* dev_mgr; // Not owned + OpKernelContext* op_ctx; // Not owned + OpKernelContext::Params* op_params; // Not owned + const CollectiveParams& col_params; + const string exec_key; + const int64 step_id; + const Tensor* input; // Not owned + Tensor* output; // Not owned + Device* device; // The device for which this instance labors + const string device_name; + DeviceLocality device_locality; +}; + +// Interface of a Collective Op implementation. Each specific CollectiveOp will +// implement this interface and register the implementation via the +// CollectiveRegistry detailed below. See common_runtime/ring_reducer and +// common_runtime/hierarchical_tree_broadcaster for examples. +class CollectiveImplementationInterface { + public: + virtual ~CollectiveImplementationInterface() = default; + + // Initializes the portions of `col_params` specific to this + // implementation. 
Called exactly once for every Collective instance during
+  // the CollectiveParams resolution process when the graph is first executed.
+  // NOTE(ayushd): This is effectively a static function because it modifies
+  // the `col_params` passed in and should not manipulate any data members.
+  // However, because it is virtual and needs to be implemented by every
+  // derived class, we do not mark it as static.
+  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
+
+  // Prepares the CollectiveContext for executing this CollectiveImplementation.
+  // Called from CollectiveExecutor right before calling Run(). The
+  // CollectiveContext passed in must outlive the CollectiveImplementation
+  // object.
+  virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+
+  // Processes and moves data according to the logic of this Collective
+  // implementation. Relies on appropriate initialization of op-specific
+  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
+  // context initialization in InitializeCollectiveContext().
+  virtual void Run(StatusCallback done) = 0;
+};
+
+// Static-methods-only class for registering and looking up collective
+// implementations.
+class CollectiveRegistry {
+ public:
+  using Factory = std::function<CollectiveImplementationInterface*()>;
+
+  // Looks up a previously registered CollectiveImplementation under
+  // `collective_name`. If found, creates an instance of the implementation
+  // and assigns it to `implementation`.
+  static Status Lookup(const string& collective_name,
+                       CollectiveImplementationInterface** implementation);
+
+  // Looks up a previously registered CollectiveImplementation under
+  // `collective_name`. If found, returns the static instance of this
+  // implementation via `implementation`. This instance should only be used
+  // to call InitializeCollectiveParams.
+  static Status LookupParamResolverInstance(
+      const string& collective_name,
+      CollectiveImplementationInterface** implementation);
+
+  // Returns all registered collective implementations.
+  static void GetAll(
+      std::vector<CollectiveImplementationInterface*>* implementations);
+
+ private:
+  friend class CollectiveRegistration;
+  // Registers a CollectiveImplementation with name `collective_name` and
+  // factory `factory`. The latter is a function used to create instances of
+  // the CollectiveImplementation. Also creates a static instance of the
+  // implementation - this instance is used during param resolution and should
+  // only be used to call InitializeCollectiveParams.
+  static Status Register(const string& collective_name, Factory factory);
+
+  static Status LookupHelper(const string& collective_name,
+                             CollectiveImplementationInterface** implementation,
+                             bool param_resolver);
+};
+
+// Class used to call CollectiveRegistry::Register. This should only be used
+// to create a global static object.
+class CollectiveRegistration {
+ public:
+  CollectiveRegistration(const string& collective_name,
+                         CollectiveRegistry::Factory factory) {
+    TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
+  }
+};
+
+#define REGISTER_COLLECTIVE(name, implementation)             \
+  static CollectiveRegistration register_##name##_collective( \
+      #name, []() { return new implementation; });
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
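For reference, a minimal sketch of how a further collective could plug into the registry introduced by this change; `NoopBroadcast` and its trivial method bodies are hypothetical illustrations, not part of the patch, which registers only RingReduce and HierarchicalTreeBroadcast:

// noop_broadcast.cc (hypothetical) - illustrates the registration pattern.
#include "tensorflow/core/framework/collective.h"

namespace tensorflow {

class NoopBroadcast : public CollectiveImplementationInterface {
 public:
  // Nothing implementation-specific to resolve for this sketch.
  Status InitializeCollectiveParams(CollectiveParams* col_params) override {
    return Status::OK();
  }

  // Save the context; real implementations also fill in device and
  // device_locality (see collective_util::InitializeDeviceAndLocality above).
  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
    col_ctx_ = col_ctx;
    return Status::OK();
  }

  // Report immediate success without moving any data.
  void Run(StatusCallback done) override { done(Status::OK()); }

 private:
  CollectiveContext* col_ctx_ = nullptr;  // Not owned.
};

REGISTER_COLLECTIVE(NoopBroadcast, NoopBroadcast);

}  // namespace tensorflow

At execution time the runtime would resolve the name recorded during param resolution to a fresh instance, e.g.:

CollectiveImplementationInterface* impl = nullptr;
Status s = CollectiveRegistry::Lookup(
    col_params.instance.impl_details.collective_name, &impl);
// Lookup() returns a newly created instance via the registered factory;
// the caller owns it.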