-rw-r--r--  tensorflow/core/BUILD                                                     10
-rw-r--r--  tensorflow/core/common_runtime/base_collective_executor.cc               148
-rw-r--r--  tensorflow/core/common_runtime/base_collective_executor.h                 20
-rw-r--r--  tensorflow/core/common_runtime/broadcaster.cc                            300
-rw-r--r--  tensorflow/core/common_runtime/collective_param_resolver_local.cc        237
-rw-r--r--  tensorflow/core/common_runtime/collective_param_resolver_local.h          17
-rw-r--r--  tensorflow/core/common_runtime/collective_param_resolver_local_test.cc   204
-rw-r--r--  tensorflow/core/common_runtime/collective_util.cc                          83
-rw-r--r--  tensorflow/core/common_runtime/collective_util.h                           38
-rw-r--r--  tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc          440
-rw-r--r--  tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h (renamed from tensorflow/core/common_runtime/broadcaster.h)            58
-rw-r--r--  tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc (renamed from tensorflow/core/common_runtime/broadcaster_test.cc)  239
-rw-r--r--  tensorflow/core/common_runtime/ring_reducer.cc                           320
-rw-r--r--  tensorflow/core/common_runtime/ring_reducer.h                             55
-rw-r--r--  tensorflow/core/common_runtime/ring_reducer_test.cc                      112
-rw-r--r--  tensorflow/core/framework/collective.cc                                  102
-rw-r--r--  tensorflow/core/framework/collective.h                                   113
17 files changed, 1427 insertions, 1069 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 44662ea79e..51225f34bc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2707,12 +2707,13 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/base_collective_executor.h",
"common_runtime/bfc_allocator.h",
- "common_runtime/broadcaster.h",
+ "common_runtime/hierarchical_tree_broadcaster.h",
"common_runtime/buf_rendezvous.h",
"common_runtime/build_graph_options.h",
"common_runtime/collective_executor_mgr.h",
"common_runtime/collective_param_resolver_local.h",
"common_runtime/collective_rma_local.h",
+ "common_runtime/collective_util.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/costmodel_manager.h",
@@ -2758,12 +2759,12 @@ tf_cuda_library(
"common_runtime/allocator_retry.cc",
"common_runtime/base_collective_executor.cc",
"common_runtime/bfc_allocator.cc",
- "common_runtime/broadcaster.cc",
"common_runtime/buf_rendezvous.cc",
"common_runtime/build_graph_options.cc",
"common_runtime/collective_executor_mgr.cc",
"common_runtime/collective_param_resolver_local.cc",
"common_runtime/collective_rma_local.cc",
+ "common_runtime/collective_util.cc",
"common_runtime/constant_folding.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/costmodel_manager.cc",
@@ -2778,6 +2779,7 @@ tf_cuda_library(
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
+ "common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/local_device.cc",
"common_runtime/lower_if_op.cc",
"common_runtime/memory_types.cc",
@@ -3664,10 +3666,10 @@ tf_cc_tests_gpu(
)
tf_cc_tests_gpu(
- name = "broadcaster_test",
+ name = "hierarchical_tree_broadcaster_test",
size = "small",
srcs = [
- "common_runtime/broadcaster_test.cc",
+ "common_runtime/hierarchical_tree_broadcaster_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 425a628a49..5b01f7fa03 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -14,13 +14,28 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include <algorithm>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
#define VALUE_IN_DEBUG_STRING false
@@ -211,104 +226,67 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
};
Tensor* output = ctx->mutable_output(0);
- string error;
- switch (col_params.instance.type) {
- case REDUCTION_COLLECTIVE: {
- // TODO(tucker): support other reduction algorithms,
- // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc.
- const Tensor* input = &ctx->input(0);
- RingReducer* reducer =
- CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_,
- input, output, &error);
- if (!reducer) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- // TODO(tucker): Instead of forking every per-device Collective
- // Op off into its own thread, consider queuing them on a
- // fixed-size thread-pool dedicated to running CollectiveOps.
- SchedClosure([reducer, done_safe]() {
- reducer->Run([reducer, done_safe](const Status& s) {
- done_safe(s);
- delete reducer;
- });
- });
- } break;
-
- case BROADCAST_COLLECTIVE: {
- Broadcaster* broadcaster = CreateBroadcaster(
- ctx, CtxParams(ctx), col_params, exec_key, step_id_, output, &error);
- if (!broadcaster) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- SchedClosure([broadcaster, done_safe]() {
- broadcaster->Run([broadcaster, done_safe](const Status& s) {
- done_safe(s);
- delete broadcaster;
- });
- });
- } break;
-
- default:
- done_safe(errors::Internal("Unimplemented CollectiveType ",
- col_params.instance.type));
+ const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
+ (col_params.instance.type == BROADCAST_COLLECTIVE &&
+ col_params.is_source))
+ ? &ctx->input(0)
+ : nullptr;
+ CollectiveImplementationInterface* col_impl = nullptr;
+ Status status = CreateCollective(col_params, &col_impl);
+ if (!status.ok()) {
+ done_safe(status);
+ DCHECK_EQ(nullptr, col_impl);
+ return;
}
-}
-
-RingReducer* BaseCollectiveExecutor::CreateReducer(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output, string* error) {
- switch (col_params.instance.data_type) {
- case DT_INT32:
- if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Reduce does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
- }
- TF_FALLTHROUGH_INTENDED;
- case DT_FLOAT:
- case DT_DOUBLE:
- case DT_INT64:
- return new RingReducer(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, input, output);
- break;
- default:
- *error = strings::StrCat("Collective Reduce does not support datatype ",
- col_params.instance.data_type);
- return nullptr;
+ CollectiveContext* col_ctx =
+ new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
+ exec_key, step_id_, input, output);
+ status = col_impl->InitializeCollectiveContext(col_ctx);
+ if (!status.ok()) {
+ done_safe(status);
+ delete col_ctx;
+ delete col_impl;
+ return;
}
+ // Run in an I/O thread, so as not to starve the executor threads.
+ // TODO(b/80529858): Instead of forking every per-device Collective
+ // Op off into its own thread, consider queuing them on a
+ // fixed-size thread-pool dedicated to running CollectiveOps.
+ SchedClosure([col_impl, col_ctx, done_safe]() {
+ col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
+ done_safe(s);
+ delete col_ctx;
+ delete col_impl;
+ });
+ });
}
-Broadcaster* BaseCollectiveExecutor::CreateBroadcaster(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- Tensor* output, string* error) {
+Status BaseCollectiveExecutor::CreateCollective(
+ const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl) {
+ *col_impl = nullptr;
+ Status status;
switch (col_params.instance.data_type) {
case DT_INT32:
if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Broadcast does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype DT_INT32 on "
+ "DEVICE_GPU");
}
TF_FALLTHROUGH_INTENDED;
case DT_FLOAT:
case DT_DOUBLE:
case DT_INT64: {
- return new Broadcaster(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, output);
- } break;
+ status = CollectiveRegistry::Lookup(
+ col_params.instance.impl_details.collective_name, col_impl);
+ break;
+ }
default:
- *error =
- strings::StrCat("Collective Broadcast does not support datatype ",
- DataTypeString(col_params.instance.data_type));
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype ",
+ col_params.instance.data_type);
}
+ return status;
}
} // namespace tensorflow
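
The hunk above replaces the per-type CreateReducer/CreateBroadcaster paths with a single flow: look up a CollectiveImplementationInterface by name, bind it to a freshly allocated CollectiveContext, then run it on an I/O thread and release both objects in the completion callback. The following is a minimal, self-contained sketch of that ownership and dispatch pattern; every type and name in it is a hypothetical stand-in, not the TensorFlow API.

#include <functional>
#include <iostream>
#include <string>

// Hypothetical stand-ins for CollectiveContext / CollectiveImplementationInterface.
struct Context { std::string exec_key; };
using StatusCallback = std::function<void(bool ok)>;

struct CollectiveImpl {
  virtual ~CollectiveImpl() = default;
  virtual bool InitializeContext(Context* ctx) = 0;
  virtual void Run(StatusCallback done) = 0;
};

struct RingReduceImpl : CollectiveImpl {
  bool InitializeContext(Context* ctx) override { ctx_ = ctx; return true; }
  void Run(StatusCallback done) override { done(true); }  // real work elided
  Context* ctx_ = nullptr;
};

// Name-keyed factory, analogous in spirit to CollectiveRegistry::Lookup.
CollectiveImpl* Lookup(const std::string& name) {
  if (name == "RingReduce") return new RingReduceImpl;
  return nullptr;
}

int main() {
  CollectiveImpl* impl = Lookup("RingReduce");
  if (!impl) return 1;  // in ExecuteAsync a failed lookup is reported through done()
  Context* ctx = new Context{"step7:instance3"};
  if (!impl->InitializeContext(ctx)) {
    delete ctx;
    delete impl;
    return 1;
  }
  // As in ExecuteAsync, the completion callback owns the cleanup of both objects.
  impl->Run([impl, ctx](bool ok) {
    std::cout << (ok ? "collective done" : "collective failed") << "\n";
    delete ctx;
    delete impl;
  });
  return 0;
}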
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 3af9286264..360ce4db7b 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -15,15 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
+#include <memory>
#include <string>
+
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class Broadcaster;
+class CollectiveImplementation;
class DeviceMgr;
-class RingReducer;
+class Device;
// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
@@ -133,18 +135,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_;
private:
- RingReducer* CreateReducer(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output,
- string* error);
-
- Broadcaster* CreateBroadcaster(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- Tensor* output, string* error);
+ Status CreateCollective(const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
deleted file mode 100644
index e1c6b21939..0000000000
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ /dev/null
@@ -1,300 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
-
-#include "tensorflow/core/common_runtime/collective_rma_local.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/platform/env.h"
-
-// Set true for greater intelligibility of debug mode log messages.
-#define READABLE_KEYS false
-
-namespace tensorflow {
-
-namespace {
-// Key to be used for BufRendezvous by Broadcaster.
-string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
- int dst_rank) {
- if (READABLE_KEYS) {
- return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
- "):src(", src_rank, "):dst(", dst_rank, ")");
- } else {
- // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
- return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
- }
-}
-} // namespace
-
-Broadcaster::Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- col_params_(col_params),
- exec_key_(exec_key),
- rank_(col_params.subdiv_rank[0]),
- is_source_(col_params.is_source),
- output_(output),
- done_(nullptr),
- device_(nullptr) {}
-
-void Broadcaster::Run(StatusCallback done) {
- // The optimal data transfer choreography is going to very platform dependent.
- // That will be addressed by later improvements here or by platform-specific
- // overrides of collective broadcast. The initial version is simply
- // a binary tree that completely ignores DeviceLocality.
- done_ = std::move(done);
-
- // Get the device for which we're executing and look up its locality.
- status_ = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status_.ok()) {
- done_(status_);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- RunTree();
-}
-
-// Binary tree parent/child relations are trivial to calculate, i.e.
-// device at rank r is the parent of 2r+1 and 2r+2. The one exception
-// is if the source is not rank 0. We treat that case as though the
-// source is appended to the front of the rank ordering as well as
-// continuing to occupy its current position. Hence we calculate as
-// though each device's rank is actually r+1, then subtract 1 again to
-// get the descendent ranks. If the source is not rank 0 then its
-// descendants include both {0,1} and the descendents of its current
-// position. Where a non-0-rank source is a descendent of another
-// device, no send to it is necessary.
-
-/* static*/
-int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return -1;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
- if (my_rank == source_rank) return -1;
- if (source_rank == 0) {
- return (my_rank - 1) / 2;
- } else {
- int predecessor_rank = (my_rank / 2) - 1;
- return (predecessor_rank < 0) ? source_rank : predecessor_rank;
- }
-}
-
-/* static */
-void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv,
- std::vector<int>* targets) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
-
- int group_size = 0;
- for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
- if (impl.subdiv_permutations[subdiv][i] >= 0) {
- group_size++;
- }
- }
-
- targets->clear();
- int successor_rank = 0;
- if (source_rank == 0) {
- successor_rank = (2 * my_rank) + 1;
- } else {
- successor_rank = (2 * (my_rank + 1));
- }
- DCHECK_NE(successor_rank, my_rank);
- if (cp.is_source && source_rank != 0) {
- // The source sends to rank 0,1 in addition to its positional
- // descendants.
- if (group_size > 1) {
- targets->push_back(0);
- }
- if (group_size > 2 && source_rank != 1) {
- targets->push_back(1);
- }
- }
- for (int i = 0; i < 2; ++i) {
- if (successor_rank < group_size && successor_rank != source_rank) {
- targets->push_back(successor_rank);
- }
- ++successor_rank;
- }
-}
-
-// Executes a hierarchical tree broadcast.
-// Each subdiv is a broadcast between a subset of the devices.
-// If there is only one task, there is one subdiv comprising a broadcast between
-// all devices belonging to the task.
-// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
-// subdiv, one device from each task participates in a binary tree broadcast.
-// Each task receives a copy of the tensor on one device via this broadcast.
-// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
-// corresponds to broadcast between all devices on task i. Thus, each task
-// participates in at most 2 subdivs.
-void Broadcaster::RunTree() {
- int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
- // TODO(ayushd): this is easily improved when a node participates in both
- // first and second subdivision. It would first send to its descendents in
- // the first subdiv, then wait until all pending ops are finished before
- // sending to descendents in second subdiv. A better implementation would
- // collapse the two send blocks.
- for (int si = 0; si < num_subdivs; si++) {
- int my_rank = col_params_.subdiv_rank[si];
- // If rank is -1, this device does not participate in this subdiv.
- if (-1 == my_rank) continue;
- int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si];
- if (VLOG_IS_ON(1)) {
- string subdiv_buf;
- for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) {
- strings::StrAppend(&subdiv_buf, r, ",");
- }
- VLOG(1) << "Running Broadcast tree device=" << device_->name()
- << " subdiv=" << si << " perm=" << subdiv_buf
- << " my_rank=" << my_rank << " source_rank=" << source_rank;
- }
-
- mutex mu; // also guards status_ while callbacks are pending
- int pending_count = 0; // GUARDED_BY(mu)
- condition_variable all_done;
-
- if (my_rank >= 0 && my_rank != source_rank) {
- // Begin by receiving the value.
- int recv_from_rank = TreeRecvFrom(col_params_, si);
- Notification note;
- DispatchRecv(si, recv_from_rank, my_rank, output_,
- [this, &mu, &note](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- note.Notify();
- });
- note.WaitForNotification();
- }
-
- // Then forward value to all descendent devices.
- if (my_rank >= 0 && status_.ok()) {
- std::vector<int> send_to_ranks;
- TreeSendTo(col_params_, si, &send_to_ranks);
- for (int i = 0; i < send_to_ranks.size(); ++i) {
- int target_rank = send_to_ranks[i];
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DispatchSend(si, target_rank, my_rank,
- (is_source_ ? &ctx_->input(0) : output_),
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (pending_count == 0) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // For the original source device, we copy input to output if they are
- // different.
- // If there is only 1 subdiv, we do this in that subdiv. If there is more
- // than 1 subdiv, then the original source device will participate in 2
- // subdivs - the global inter-task broadcast and one local intra-task
- // broadcast. In this case, we perform the copy in the second subdiv for
- // this device.
- if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
- VLOG(2) << "copying input to output for device=" << device_->name()
- << " subdiv=" << si;
- const Tensor* input = &ctx_->input(0);
- if (input != output_ &&
- (DMAHelper::base(input) != DMAHelper::base(output_))) {
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- CollectiveRemoteAccessLocal::MemCpyAsync(
- op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
- ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (0 == pending_count) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // Then wait for all pending actions to complete.
- {
- mutex_lock l(mu);
- if (pending_count > 0) {
- all_done.wait(l);
- }
- }
- }
- VLOG(2) << "device=" << device_->name() << " return status " << status_;
- done_(status_);
-}
-
-void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank,
- const Tensor* src_tensor,
- const StatusCallback& done) {
- string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int dst_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank];
- VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
- << device_->name() << " to_device "
- << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv
- << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
- col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
- col_params_.instance.task_names[dst_idx], send_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), src_tensor,
- device_locality_, done);
-}
-
-void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank,
- Tensor* dst_tensor, const StatusCallback& done) {
- string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int src_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank];
- VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
- << col_params_.instance.device_names[src_idx] << " to_device "
- << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank
- << " src_idx=" << src_idx;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
- col_params_.instance.task_names[src_idx],
- col_params_.task.is_local[src_idx], recv_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, 0 /*stream_index*/, done);
-}
-
-} // namespace tensorflow
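
The tree rank arithmetic deleted here reappears essentially unchanged in hierarchical_tree_broadcaster.cc below (TreeRecvFrom/TreeSendTo). A worked example helps for the non-obvious case where the source is not rank 0. The sketch below is not TensorFlow code; it reimplements just the parent/children rules for a single subdivision. For 8 devices with source rank 2 it yields the edges 2->{0,1,6,7}, 0->{3}, 1->{4,5}, so every non-source rank receives exactly once.

#include <iostream>
#include <vector>

// Rank this device receives from (-1 for the source), per TreeRecvFrom.
int RecvFrom(int my_rank, int source_rank) {
  if (my_rank == source_rank) return -1;
  if (source_rank == 0) return (my_rank - 1) / 2;      // plain binary-heap parent
  int predecessor = (my_rank / 2) - 1;                 // ranks shifted by one
  return predecessor < 0 ? source_rank : predecessor;
}

// Ranks this device sends to, per TreeSendTo (single subdiv, all ranks valid).
std::vector<int> SendTo(int my_rank, int source_rank, int group_size) {
  std::vector<int> targets;
  int successor = (source_rank == 0) ? 2 * my_rank + 1 : 2 * (my_rank + 1);
  if (my_rank == source_rank && source_rank != 0) {
    // A non-rank-0 source additionally feeds ranks 0 and 1.
    if (group_size > 1) targets.push_back(0);
    if (group_size > 2 && source_rank != 1) targets.push_back(1);
  }
  for (int i = 0; i < 2; ++i, ++successor) {
    if (successor < group_size && successor != source_rank)
      targets.push_back(successor);
  }
  return targets;
}

int main() {
  const int group_size = 8, source_rank = 2;
  for (int r = 0; r < group_size; ++r) {
    std::cout << "rank " << r << " recv_from " << RecvFrom(r, source_rank)
              << " send_to";
    for (int t : SendTo(r, source_rank, group_size)) std::cout << ' ' << t;
    std::cout << '\n';
  }
}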
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 2a14493a67..52eedae9b7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -14,7 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include <stddef.h>
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -319,206 +332,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
}
} // namespace
-int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
- int num_tasks = static_cast<int>(dev_per_task.size());
- int task_lo = 0;
- int task_hi;
- for (int ti = 0; ti < num_tasks; ti++) {
- task_hi = task_lo + dev_per_task[ti];
- if (task_lo <= device_rank && device_rank < task_hi) return ti;
- task_lo += dev_per_task[ti];
- }
- LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
- << " devices";
- return -1;
-}
-
-void CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- const string& device, int source_rank, const std::vector<int>& dev_per_task,
- CollectiveParams* cp) {
- if (VLOG_IS_ON(1)) {
- string dpt_buf;
- for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
- VLOG(1) << "GenerateBcastSubdivPerms device=" << device
- << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf;
- }
- int num_tasks = cp->group.num_tasks;
- // If there is just 1 task, then execute binary tree broadcast over all
- // devices. Otherwise, the first subdiv is inter-task broadcast, and then
- // there are N more subdivs, where N is #task.
- int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
- int total_num_devices = 0;
- for (int num_dev : dev_per_task) total_num_devices += num_dev;
-
- cp->instance.impl_details.subdiv_permutations.resize(num_subdivs);
- cp->subdiv_rank.reserve(num_subdivs);
- cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
-
- // Inter-task subdiv. Pick one device from each task - this is the source
- // device if it belongs to that task, or device 0 for that task. If a device
- // does not participate in the subdiv, set subdiv_rank to -1.
- if (num_tasks > 1) {
- const int sdi = 0;
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int device_count = 0;
- int source_task = GetDeviceTask(source_rank, dev_per_task);
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- bool participate = false;
- if (source_task == ti) {
- // Source device belongs to this task.
- perm.push_back(source_rank);
- participate = cp->instance.device_names[source_rank] == device;
- } else {
- // Source does not belong to this task, choose dev 0.
- perm.push_back(device_count);
- participate = cp->instance.device_names[device_count] == device;
- }
- if (participate) cp->subdiv_rank.push_back(ti);
- device_count += dev_per_task[ti];
- }
- if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(source_task);
- }
-
- // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
- // source to dev 0 for that task if it does not contain original source, else
- // set to rank of original source. If a device does not participate in the
- // subdiv, set subdiv_rank to -1;
- int abs_di = 0;
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- const int sdi = ti + (num_tasks > 1 ? 1 : 0);
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- bool participate = false;
- int subdiv_source = 0;
- for (int di = 0; di < dev_per_task[ti]; di++) {
- perm.push_back(abs_di);
- if (cp->instance.device_names[abs_di] == device) {
- participate = true;
- cp->subdiv_rank.push_back(di);
- }
- if (abs_di == source_rank) subdiv_source = di;
- abs_di++;
- }
- if (!participate) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source);
- }
-
- for (int sri = 0; sri < num_subdivs; sri++) {
- CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0);
- }
-}
-
-// Establish the requested number of subdivision permutations based on the
-// ring order implicit in the device order.
-/*static*/
-void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
- int source_rank,
- CollectiveParams* cp) {
- // Each subdiv permutation is a ring formed by rotating each
- // single-task subsequence of devices by an offset. This makes most
- // sense when each task has the same number of devices but we can't
- // depend on that being the case so we'll compute something that
- // works in any case.
-
- // Start by counting the devices in each task.
- // Precondition: device_names must be sorted so that all devices in
- // the same task are adjacent.
- VLOG(2) << "Sorted task names: "
- << str_util::Join(cp->instance.task_names, ", ");
- std::vector<int> dev_per_task;
- const string* prior_task_name = &cp->instance.task_names[0];
- int dev_count = 1;
- for (int di = 1; di < cp->group.group_size; ++di) {
- if (cp->instance.task_names[di] != *prior_task_name) {
- dev_per_task.push_back(dev_count);
- dev_count = 1;
- prior_task_name = &cp->instance.task_names[di];
- } else {
- ++dev_count;
- }
- }
- dev_per_task.push_back(dev_count);
- CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
-
- CHECK(cp->instance.type == REDUCTION_COLLECTIVE ||
- cp->instance.type == BROADCAST_COLLECTIVE);
- if (cp->instance.type == REDUCTION_COLLECTIVE) {
- // Generate a ring permutation for each requested offset.
- CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
- VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
- << &cp->instance.impl_details.subdiv_permutations;
- cp->instance.impl_details.subdiv_permutations.resize(
- cp->instance.impl_details.subdiv_offsets.size());
- cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
- ++sdi) {
- std::vector<int>& perm =
- cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- // A negative subdivision offset is interpreted as follows:
- // 1. Reverse the local device ordering.
- // 2. Begin the subdivision at abs(offset) in the reversed ordering.
- bool reverse = false;
- if (offset < 0) {
- offset = abs(offset);
- reverse = true;
- }
- int prior_dev_count = 0; // sum over prior worker device counts
- for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
- for (int di = 0; di < dev_per_task[ti]; ++di) {
- int di_offset = (di + offset) % dev_per_task[ti];
- int offset_di =
- reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
- // Device index in global subdivision permutation.
- int permuted_di = prior_dev_count + offset_di;
- int rank = static_cast<int>(perm.size());
- perm.push_back(permuted_di);
- if (cp->instance.device_names[permuted_di] == device) {
- CHECK_EQ(permuted_di, cp->default_rank);
- cp->subdiv_rank[sdi] = rank;
- }
- }
- prior_dev_count += dev_per_task[ti];
- }
- CHECK_EQ(cp->group.group_size, perm.size());
- }
- } else if (cp->instance.type == BROADCAST_COLLECTIVE) {
- GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp);
- }
-
- if (VLOG_IS_ON(1)) {
- // Log the computed ring order for each subdiv.
- string buf;
- for (int sdi = 0;
- sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
- buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
- for (int di = 0;
- di < cp->instance.impl_details.subdiv_permutations[sdi].size();
- ++di) {
- int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
- if (idx >= 0) {
- CHECK_GT(cp->instance.device_names.size(), idx);
- strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
- }
- }
- strings::StrAppend(&buf, " subdiv_offsets: ");
- for (auto o : cp->instance.impl_details.subdiv_offsets)
- strings::StrAppend(&buf, o, " ");
- strings::StrAppend(&buf, " SubdivRank: ");
- for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
- if (cp->instance.type == BROADCAST_COLLECTIVE) {
- strings::StrAppend(&buf, " subdiv_source_rank: ");
- for (auto src : cp->instance.impl_details.subdiv_source_rank)
- strings::StrAppend(&buf, src, " ");
- }
- VLOG(1) << buf;
- }
- }
-}
-
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
@@ -785,29 +598,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// Populate the fields common across task, also default_rank.
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
+ // TODO(b/113171733): we need a better way to pick the collective
+ // implementation. The ideal way would depend upon the topology and link
+ // strength before picking a particular implementation.
+ cp->instance.impl_details.collective_name =
+ (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
+ : "RingReduce";
+ CollectiveImplementationInterface* col_impl;
+ Status lookup_status = CollectiveRegistry::LookupParamResolverInstance(
+ cp->instance.impl_details.collective_name, &col_impl);
+ if (!lookup_status.ok()) {
+ done(lookup_status);
+ return;
+ }
// If broadcast, may need to wait for source discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
CompleteInstanceSource(ir, cp, is_source,
- [this, ir, device, cp, done](InstanceRec* irec) {
+ [col_impl, ir, device, cp, done](InstanceRec* irec) {
CHECK_EQ(ir, irec);
Status s;
- int source_rank;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
- source_rank = irec->source_rank;
+ cp->source_rank = irec->source_rank;
}
if (s.ok()) {
- GenerateSubdivPerms(device, source_rank, cp);
+ s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
- return;
} else {
- GenerateSubdivPerms(device, 0, cp);
+ done(col_impl->InitializeCollectiveParams(cp));
}
- done(Status::OK());
}
void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
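
GenerateBcastSubdivPerms, removed from the resolver above, moves into HierarchicalTreeBroadcaster::InitializeCollectiveParams later in this diff. As a concrete illustration of the subdivision layout it produces, the sketch below (not TensorFlow code) applies the same rules to 2 tasks with 2 devices each and source rank 2: one inter-task subdiv {0, 2} with subdiv source 1, plus one intra-task subdiv per task, {0, 1} and {2, 3}, each with subdiv source 0.

#include <iostream>
#include <vector>

int main() {
  const std::vector<int> dev_per_task = {2, 2};  // 2 tasks x 2 devices, global ranks 0..3
  const int source_rank = 2;                     // broadcast source is task 1, device 0

  // Which task owns the source device?
  int source_task = 0;
  for (int ti = 0, lo = 0; ti < static_cast<int>(dev_per_task.size()); ++ti) {
    if (source_rank < lo + dev_per_task[ti]) { source_task = ti; break; }
    lo += dev_per_task[ti];
  }

  std::vector<std::vector<int>> perms;
  std::vector<int> subdiv_source;

  // Subdiv 0 (inter-task): one representative per task -- the source device
  // for its own task, device 0 of every other task.
  std::vector<int> inter;
  for (int ti = 0, dev_count = 0; ti < static_cast<int>(dev_per_task.size()); ++ti) {
    inter.push_back(ti == source_task ? source_rank : dev_count);
    dev_count += dev_per_task[ti];
  }
  perms.push_back(inter);
  subdiv_source.push_back(source_task);

  // Subdivs 1..N (intra-task): every device of one task; the subdiv source is
  // the original source if it lives in that task, otherwise local device 0.
  int abs_di = 0;
  for (int ti = 0; ti < static_cast<int>(dev_per_task.size()); ++ti) {
    std::vector<int> perm;
    int src = 0;
    for (int di = 0; di < dev_per_task[ti]; ++di, ++abs_di) {
      perm.push_back(abs_di);
      if (abs_di == source_rank) src = di;
    }
    perms.push_back(perm);
    subdiv_source.push_back(src);
  }

  // Prints: subdiv 0: perm 0 2 source 1
  //         subdiv 1: perm 0 1 source 0
  //         subdiv 2: perm 2 3 source 0
  for (int s = 0; s < static_cast<int>(perms.size()); ++s) {
    std::cout << "subdiv " << s << ": perm";
    for (int r : perms[s]) std::cout << ' ' << r;
    std::cout << " source " << subdiv_source[s] << '\n';
  }
}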
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 9372fd6272..c5c3497e28 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -15,7 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#include <functional>
+#include <memory>
+#include <set>
#include <string>
+#include <vector>
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// Used to complete/verify CollInstance.
struct InstanceRec;
+
typedef std::function<void(InstanceRec*)> IRConsumer;
struct InstanceRec {
// This structure has two mutexes so that a possibly long
@@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);
- friend class CollectiveParamResolverLocalTest;
- // Establishes the requested number of subdivision permutations based on the
- // ring order implicit in the device order.
- static void GenerateSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp);
- // Establishes the subdivisions for broadcast op. The first subdiv executes
- // binary tree bcast with one device per task. Each subsequent subdiv
- // executes intra-task binary tree broadcast.
- static void GenerateBcastSubdivPerms(const string& device, int source_rank,
- const std::vector<int>& dev_per_task,
- CollectiveParams* cp);
-
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 9ea23b72d2..9e1e2e8d5b 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -44,31 +44,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
task_name));
}
- void GenSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp) {
- CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
- }
-
- // Calls GenerateBcastSubdivPerms for device at `device_rank`. Checks if the
- // generated subdiv perms, ranks, and source ranks match the expected values.
- void BcastSubdivPerms(
- CollectiveParams* cp, const std::vector<int>& dev_per_task,
- int device_rank, int source_rank,
- const std::vector<std::vector<int>>& expected_subdiv_perms,
- const std::vector<int>& expected_subdiv_rank,
- const std::vector<int>& expected_subdiv_source_rank) {
- cp->subdiv_rank.clear();
- cp->instance.impl_details.subdiv_permutations.clear();
- cp->instance.impl_details.subdiv_source_rank.clear();
- CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- cp->instance.device_names[device_rank], source_rank, dev_per_task, cp);
- EXPECT_EQ(expected_subdiv_perms,
- cp->instance.impl_details.subdiv_permutations);
- EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
- EXPECT_EQ(expected_subdiv_source_rank,
- cp->instance.impl_details.subdiv_source_rank);
- }
-
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@@ -114,7 +89,6 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i].is_source);
EXPECT_EQ(cps[i].default_rank, i);
@@ -161,188 +135,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- ASSERT_GT(cps[i].subdiv_rank.size(), 0);
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
- ASSERT_GT(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
- EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank[0], 1);
EXPECT_EQ(cps[i].is_source, (i == 1));
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
}
}
-TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
- CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
- cp.group.group_key = 1;
- cp.group.group_size = kNumDevs;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
- cp.instance.instance_key = 3;
- cp.instance.type = REDUCTION_COLLECTIVE;
- cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
- cp.instance.impl_details.subdiv_offsets.push_back(0);
- cp.is_source = false;
- for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
- string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
- string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
- cp.instance.task_names.push_back(task_name);
- cp.instance.device_names.push_back(device_name);
- }
-
- int test_rank = 0;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {0, 4};
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- std::vector<int> expected_0 = {0, 1, 2, 3, 4, 5, 6, 7,
- 8, 9, 10, 11, 12, 13, 14, 15,
- 16, 17, 18, 19, 20, 21, 22, 23};
- std::vector<int> expected_1 = {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
- 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(4, cp.subdiv_rank[1]);
-
- test_rank = 3;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {3, -3};
- cp.instance.impl_details.subdiv_permutations.clear();
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- expected_0 = {3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
- 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18};
- expected_1 = {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
- 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(1, cp.subdiv_rank[1]);
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 1;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int i = 0; i < 8; i++) {
- string dev_name =
- strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i);
- cp.instance.device_names.push_back(dev_name);
- }
- std::vector<int> dev_per_task = {8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {0});
-
- // source 2 device 2
- BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2},
- {2});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {2});
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < 8; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
- std::vector<int> dev_per_task = {8, 8, 8, 8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 9
- BcastSubdivPerms(&cp, dev_per_task, 9, 9,
- {{0, 9, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
-}
-
-TEST_F(CollectiveParamResolverLocalTest,
- GenerateBcastSubdivPerms4TasksVariableGPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- std::vector<int> dev_per_task = {4, 4, 6, 8};
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < dev_per_task[ti]; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 5
- BcastSubdivPerms(&cp, dev_per_task, 5, 9,
- {{0, 4, 9, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
-}
-
} // namespace tensorflow
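
The GenerateSubdivPerms ring tests deleted above exercise the offset rule removed from the resolver: each task's devices are rotated by |offset|, and a negative offset reverses the local device order before rotating. The sketch below (not TensorFlow code) reproduces just that rule; for 3 tasks of 8 GPUs it prints 4 5 6 7 0 1 2 3 12 13 14 15 8 9 10 11 ... for offset 4 and 4 3 2 1 0 7 6 5 12 11 10 9 8 15 14 13 ... for offset -3, matching the expectations in the removed test.

#include <cstdlib>
#include <iostream>
#include <vector>

// One ring permutation for a given subdiv offset: rotate each task's devices
// by |offset|; a negative offset reverses the local order before rotating.
std::vector<int> RingPerm(const std::vector<int>& dev_per_task, int offset) {
  const bool reverse = offset < 0;
  offset = std::abs(offset);
  std::vector<int> perm;
  int prior = 0;  // number of devices in earlier tasks
  for (int n : dev_per_task) {
    for (int di = 0; di < n; ++di) {
      int di_offset = (di + offset) % n;
      int offset_di = reverse ? n - (di_offset + 1) : di_offset;
      perm.push_back(prior + offset_di);
    }
    prior += n;
  }
  return perm;
}

int main() {
  const std::vector<int> dev_per_task = {8, 8, 8};  // 3 tasks x 8 GPUs, as in the removed test
  for (int offset : {4, -3}) {
    std::cout << "offset " << offset << ":";
    for (int v : RingPerm(dev_per_task, offset)) std::cout << ' ' << v;
    std::cout << '\n';
  }
}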
diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc
new file mode 100644
index 0000000000..195521a078
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_util.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+/*static*/
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality) {
+ if (!dev_mgr) {
+ return errors::Internal("Required non-null dev_mgr ", dev_mgr,
+ " for InitializeDeviceAndLocality");
+ }
+
+ Status status = dev_mgr->LookupDevice(device_name, device);
+ if (status.ok()) {
+ CHECK(*device);
+ *device_locality = (*device)->attributes().locality();
+ } else {
+ LOG(ERROR) << "Failed to find device " << device_name;
+ for (auto d : dev_mgr->ListDevices()) {
+ LOG(ERROR) << "Available devices " << d->name();
+ }
+ }
+ return status;
+}
+
+/*static*/
+string SubdivPermDebugString(const CollectiveParams& col_params) {
+ const auto& subdiv_perms =
+ col_params.instance.impl_details.subdiv_permutations;
+ string buf;
+ for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) {
+ strings::StrAppend(&buf, "Subdiv ", sdi, " device order:\n");
+ for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
+ int idx = subdiv_perms[sdi][di];
+ if (idx >= 0) {
+ CHECK_GT(col_params.instance.device_names.size(), idx);
+ strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n");
+ }
+ }
+ strings::StrAppend(&buf, " subdiv_offsets: ");
+ for (auto o : col_params.instance.impl_details.subdiv_offsets)
+ strings::StrAppend(&buf, o, " ");
+ strings::StrAppend(&buf, " SubdivRank: ");
+ for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " ");
+ if (col_params.instance.type == BROADCAST_COLLECTIVE) {
+ strings::StrAppend(&buf, " subdiv_source_rank: ");
+ for (auto src : col_params.instance.impl_details.subdiv_source_rank)
+ strings::StrAppend(&buf, src, " ");
+ }
+ strings::StrAppend(&buf, "\n");
+ }
+ return buf;
+}
+
+} // namespace collective_util
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h
new file mode 100644
index 0000000000..ebb5731bec
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality);
+string SubdivPermDebugString(const CollectiveParams& col_params);
+
+} // namespace collective_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
new file mode 100644
index 0000000000..eae34997d9
--- /dev/null
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -0,0 +1,440 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+// Set true for greater intelligibility of debug mode log messages.
+#define READABLE_KEYS false
+
+namespace tensorflow {
+
+namespace {
+// Key to be used for BufRendezvous by Broadcaster.
+string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
+ int dst_rank) {
+ if (READABLE_KEYS) {
+ return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
+ "):src(", src_rank, "):dst(", dst_rank, ")");
+ } else {
+ // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
+ return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
+ }
+}
+} // namespace
+
+HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
+ done_(nullptr),
+ is_source_(false) {}
+
+int HierarchicalTreeBroadcaster::GetDeviceTask(
+ int device_rank, const std::vector<int>& dev_per_task) {
+ int num_tasks = static_cast<int>(dev_per_task.size());
+ int task_lo = 0;
+ int task_hi;
+ for (int ti = 0; ti < num_tasks; ti++) {
+ task_hi = task_lo + dev_per_task[ti];
+ if (task_lo <= device_rank && device_rank < task_hi) return ti;
+ task_lo = task_hi;
+ }
+ LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
+ << " devices";
+ return -1;
+}
+
+Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
+ CollectiveParams* col_params) {
+ CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name,
+ "HierarchicalTreeBroadcast");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ if (VLOG_IS_ON(2)) {
+ string dpt_buf;
+ for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
+ VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
+ << device_name << " source_rank=" << col_params->source_rank
+ << " dev_per_task=" << dpt_buf;
+ }
+ int num_tasks = col_params->group.num_tasks;
+ // If there is just 1 task, then execute binary tree broadcast over all
+ // devices. Otherwise, the first subdiv is inter-task broadcast, and then
+ // there are N more subdivs, where N is #task.
+ int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
+ int total_num_devices = 0;
+ for (int num_dev : dev_per_task) total_num_devices += num_dev;
+
+ col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
+ col_params->subdiv_rank.reserve(num_subdivs);
+ col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
+
+ // Inter-task subdiv. Pick one device from each task - this is the source
+ // device if it belongs to that task, or device 0 for that task. If a device
+ // does not participate in the subdiv, set subdiv_rank to -1.
+ if (num_tasks > 1) {
+ const int sdi = 0;
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int device_count = 0;
+ int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ bool participate = false;
+ if (source_task == ti) {
+ // Source device belongs to this task.
+ perm.push_back(col_params->source_rank);
+ participate =
+ col_params->instance.device_names[col_params->source_rank] ==
+ device_name;
+ } else {
+ // Source does not belong to this task, choose dev 0.
+ perm.push_back(device_count);
+ participate =
+ col_params->instance.device_names[device_count] == device_name;
+ }
+ if (participate) col_params->subdiv_rank.push_back(ti);
+ device_count += dev_per_task[ti];
+ }
+ if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
+ }
+
+ // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
+ // source to dev 0 for that task if it does not contain original source, else
+ // set to rank of original source. If a device does not participate in
+ // the subdiv, set subdiv_rank to -1;
+ int abs_di = 0;
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ const int sdi = ti + (num_tasks > 1 ? 1 : 0);
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ bool participate = false;
+ int subdiv_source = 0;
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ perm.push_back(abs_di);
+ if (col_params->instance.device_names[abs_di] == device_name) {
+ participate = true;
+ col_params->subdiv_rank.push_back(di);
+ }
+ if (abs_di == col_params->source_rank) subdiv_source = di;
+ abs_di++;
+ }
+ if (!participate) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(
+ subdiv_source);
+ }
+
+ for (int sri = 0; sri < num_subdivs; sri++) {
+ CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
+ }
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
+
+Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
+ CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
+}
+
+void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
+ done_ = std::move(done);
+ is_source_ = col_params_->is_source;
+ RunTree();
+}
+
+// Binary tree parent/child relations are trivial to calculate, i.e.
+// the device at rank r is the parent of 2r+1 and 2r+2. The one exception
+// is when the source is not rank 0. We treat that case as though the
+// source were prepended to the front of the rank ordering while also
+// continuing to occupy its current position. Hence we calculate as
+// though each device's rank were actually r+1, then subtract 1 again to
+// get the descendant ranks. If the source is not rank 0 then its
+// descendants include both {0,1} and the descendants of its current
+// position. Where a non-0-rank source is a descendant of another
+// device, no send to it is necessary.
+
+/* static */
+int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
+ int subdiv) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return -1;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+ if (my_rank == source_rank) return -1;
+ if (source_rank == 0) {
+ return (my_rank - 1) / 2;
+ } else {
+ int predecessor_rank = (my_rank / 2) - 1;
+ return (predecessor_rank < 0) ? source_rank : predecessor_rank;
+ }
+}
+
+/* static */
+void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
+ int subdiv,
+ std::vector<int>* targets) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+
+ int group_size = 0;
+ for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
+ if (impl.subdiv_permutations[subdiv][i] >= 0) {
+ group_size++;
+ }
+ }
+
+ targets->clear();
+ int successor_rank = 0;
+ if (source_rank == 0) {
+ successor_rank = (2 * my_rank) + 1;
+ } else {
+ successor_rank = (2 * (my_rank + 1));
+ }
+ DCHECK_NE(successor_rank, my_rank);
+ if (cp.is_source && source_rank != 0) {
+ // The source sends to rank 0,1 in addition to its positional
+ // descendants.
+ if (group_size > 1) {
+ targets->push_back(0);
+ }
+ if (group_size > 2 && source_rank != 1) {
+ targets->push_back(1);
+ }
+ }
+ for (int i = 0; i < 2; ++i) {
+ if (successor_rank < group_size && successor_rank != source_rank) {
+ targets->push_back(successor_rank);
+ }
+ ++successor_rank;
+ }
+}
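To make the rank arithmetic above easier to follow, here is a standalone sketch (illustrative only, not part of the patch itself) that re-derives the receive/send relations for a single subdiv, assuming every device participates so that group_size equals the device count and the source device's rank equals subdiv_source_rank; the 8-device, source-rank-7 case matches the TreeLinks_8Devs_7Source_7Rank expectation in the test file below.

#include <cstdio>
#include <vector>

int RecvFrom(int my_rank, int source_rank) {
  if (my_rank == source_rank) return -1;
  if (source_rank == 0) return (my_rank - 1) / 2;
  int predecessor = (my_rank / 2) - 1;
  return predecessor < 0 ? source_rank : predecessor;
}

std::vector<int> SendTo(int my_rank, int source_rank, int group_size) {
  std::vector<int> targets;
  int successor = (source_rank == 0) ? 2 * my_rank + 1 : 2 * (my_rank + 1);
  if (my_rank == source_rank && source_rank != 0) {
    // A non-0-rank source also sends to ranks 0 and 1.
    if (group_size > 1) targets.push_back(0);
    if (group_size > 2 && source_rank != 1) targets.push_back(1);
  }
  for (int i = 0; i < 2; ++i, ++successor) {
    if (successor < group_size && successor != source_rank)
      targets.push_back(successor);
  }
  return targets;
}

int main() {
  // 8 devices, source at rank 7: rank 7 receives from nobody and sends to
  // ranks {0, 1}, matching the TreeLinks_8Devs_7Source_7Rank test below.
  for (int r = 0; r < 8; ++r) {
    printf("rank %d: recv_from=%d send_to=", r, RecvFrom(r, 7));
    for (int t : SendTo(r, 7, 8)) printf("%d,", t);
    printf("\n");
  }
  return 0;
}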
+
+// Executes a hierarchical tree broadcast.
+// Each subdiv is a broadcast between a subset of the devices.
+// If there is only one task, there is one subdiv comprising a broadcast between
+// all devices belonging to the task.
+// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
+// subdiv, one device from each task participates in a binary tree broadcast.
+// Each task receives a copy of the tensor on one device via this broadcast.
+// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
+// corresponds to a broadcast among all devices on task i. Thus, each task
+// participates in at most 2 subdivs.
+void HierarchicalTreeBroadcaster::RunTree() {
+ int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
+  // TODO(b/78352018): this is easily improved when a node participates in both
+  // the first and second subdivisions. It currently sends to its descendants
+  // in the first subdiv, then waits until all pending ops have finished before
+  // sending to its descendants in the second subdiv. A better implementation
+  // would collapse the two send blocks.
+ for (int si = 0; si < num_subdivs; si++) {
+ int my_rank = col_params_->subdiv_rank[si];
+ // If rank is -1, this device does not participate in this subdiv.
+ if (-1 == my_rank) continue;
+ int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
+ if (VLOG_IS_ON(1)) {
+ string subdiv_buf;
+ for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
+ strings::StrAppend(&subdiv_buf, r, ",");
+ }
+ VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
+ << " subdiv=" << si << " perm=" << subdiv_buf
+ << " my_rank=" << my_rank << " source_rank=" << source_rank;
+ }
+
+ mutex mu; // also guards status_ while callbacks are pending
+ int pending_count = 0; // GUARDED_BY(mu)
+ condition_variable all_done;
+
+ if (my_rank >= 0 && my_rank != source_rank) {
+ // Begin by receiving the value.
+ int recv_from_rank = TreeRecvFrom(*col_params_, si);
+ Notification note;
+ DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
+ [this, &mu, &note](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ note.Notify();
+ });
+ note.WaitForNotification();
+ }
+
+    // Then forward the value to all descendant devices.
+ if (my_rank >= 0 && status_.ok()) {
+ std::vector<int> send_to_ranks;
+ TreeSendTo(*col_params_, si, &send_to_ranks);
+ for (int i = 0; i < send_to_ranks.size(); ++i) {
+ int target_rank = send_to_ranks[i];
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DispatchSend(si, target_rank, my_rank,
+ (is_source_ ? col_ctx_->input : col_ctx_->output),
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (pending_count == 0) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // For the original source device, we copy input to output if they are
+ // different.
+ // If there is only 1 subdiv, we do this in that subdiv. If there is more
+ // than 1 subdiv, then the original source device will participate in 2
+ // subdivs - the global inter-task broadcast and one local intra-task
+ // broadcast. In this case, we perform the copy in the second subdiv for
+ // this device.
+ if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+ VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
+ << " subdiv=" << si;
+ if (col_ctx_->input != col_ctx_->output &&
+ (DMAHelper::base(col_ctx_->input) !=
+ DMAHelper::base(col_ctx_->output))) {
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ CollectiveRemoteAccessLocal::MemCpyAsync(
+ op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
+ col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0, /*stream_index*/
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (0 == pending_count) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // Then wait for all pending actions to complete.
+ {
+ mutex_lock l(mu);
+ if (pending_count > 0) {
+ all_done.wait(l);
+ }
+ }
+ }
+ VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
+ done_(status_);
+}
+
+void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
+ int src_rank,
+ const Tensor* src_tensor,
+ const StatusCallback& done) {
+ string send_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int dst_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
+ VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
+ << col_ctx_->device_name << " to_device "
+ << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
+ << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
+ col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx],
+ col_params_->instance.task_names[dst_idx],
+ send_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0),
+ src_tensor, col_ctx_->device_locality, done);
+}
+
+void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
+ int dst_rank, Tensor* dst_tensor,
+ const StatusCallback& done) {
+ string recv_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int src_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
+ VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
+ << col_params_->instance.device_names[src_idx] << " to_device "
+ << col_ctx_->device_name << " subdiv=" << subdiv
+ << " src_rank=" << src_rank << " src_idx=" << src_idx;
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[src_idx],
+ col_params_->instance.task_names[src_idx],
+ col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, 0 /*stream_index*/, done);
+}
+
+REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
index 799228b161..ceb9baad30 100644
--- a/tensorflow/core/common_runtime/broadcaster.h
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
@@ -12,25 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
#include <vector>
+
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-// Tree-algorithm implementation of collective broadcast.
-class Broadcaster {
+// Hierarchical tree-algorithm implementation of collective broadcast.
+class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
public:
- Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, Tensor* output);
+ HierarchicalTreeBroadcaster();
+ ~HierarchicalTreeBroadcaster() override = default;
+
+ // Establishes the subdiv permutations needed for a hierarchical broadcast.
+ // If all devices are local, establishes a single subdiv comprising all
+ // devices. If any devices are on a different task, establishes n+1 subdivs
+ // for n tasks.
+ // The first subdiv comprises one device per task which gets the tensor on
+ // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task
+ // i.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the hierarchical tree broadcast.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
// Returns the rank of the device from which this device should receive
// its value, -1 if no value should be received.
@@ -42,32 +57,29 @@ class Broadcaster {
std::vector<int>* targets);
private:
+  // Gets the task to which the device at `device_rank` belongs.
+ int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);
+
// Sends `src_tensor` asynchronously from this device to device at `dst_rank`
// in `subdiv`. Calls `done` upon completion.
void DispatchSend(int subdiv, int dst_rank, int src_rank,
const Tensor* src_tensor, const StatusCallback& done);
+
// Receives a tensor into the memory buffer owned by `dst_tensor` at this
// device from device at `src_rank` in `subdiv`. Calls `done` upon
// completion.
void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
const StatusCallback& done);
+
// Executes the hierarchical broadcast defined by this op.
void RunTree();
- Status status_;
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const int rank_;
- const bool is_source_;
- Tensor* output_; // Not owned
- std::unique_ptr<CollectiveAdapter> ca_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
- Device* device_; // The device for which this instance labors
- DeviceLocality device_locality_;
+ Status status_;
+ bool is_source_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 3960fc6c97..da0e359cf8 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
@@ -41,7 +41,7 @@ static int64 kStepId = 123;
// The test harness won't allow a mixture of fixture and non-fixture
// tests in one file, so this is a trivial fixture for tests that don't
-// need the heavy-weight BroadcasterTest fixture.
+// need the heavy-weight HierarchicalTreeBroadcasterTest fixture.
class TrivialTest : public ::testing::Test {
protected:
TrivialTest() {}
@@ -53,23 +53,23 @@ class TrivialTest : public ::testing::Test {
// R = tested rank
// RF = receive-from rank
// ST = send_to rank vector
-#define DEF_TL_TEST(D, S, R, RF, ST) \
- TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
- CollectiveParams cp; \
- cp.group.group_size = D; \
- cp.instance.impl_details.subdiv_source_rank = {S}; \
- cp.instance.impl_details.subdiv_permutations.push_back( \
- std::vector<int>(D, 0)); \
- cp.subdiv_rank = {R}; \
- cp.is_source = (S == R); \
- EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp, 0)); \
- std::vector<int> expected = ST; \
- std::vector<int> send_to; \
- Broadcaster::TreeSendTo(cp, 0, &send_to); \
- ASSERT_EQ(expected.size(), send_to.size()); \
- for (int i = 0; i < expected.size(); ++i) { \
- EXPECT_EQ(expected[i], send_to[i]); \
- } \
+#define DEF_TL_TEST(D, S, R, RF, ST) \
+ TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
+ CollectiveParams cp; \
+ cp.group.group_size = D; \
+ cp.instance.impl_details.subdiv_source_rank = {S}; \
+ cp.instance.impl_details.subdiv_permutations.push_back( \
+ std::vector<int>(D, 0)); \
+ cp.subdiv_rank = {R}; \
+ cp.is_source = (S == R); \
+ EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
+ std::vector<int> expected = ST; \
+ std::vector<int> send_to; \
+ HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \
+ ASSERT_EQ(expected.size(), send_to.size()); \
+ for (int i = 0; i < expected.size(); ++i) { \
+ EXPECT_EQ(expected[i], send_to[i]); \
+ } \
}
#define V(...) std::vector<int>({__VA_ARGS__})
@@ -130,7 +130,7 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
-// TODO(tucker): factor out of this file and ring_reducer_test.cc
+// TODO(b/113171733): factor out of this file and ring_reducer_test.cc
// into a single common source.
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
@@ -187,31 +187,32 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
int fail_after_ GUARDED_BY(mu_);
};
-class BroadcasterTest : public ::testing::Test {
+class HierarchicalTreeBroadcasterTest : public ::testing::Test {
protected:
- BroadcasterTest() : device_type_(DEVICE_CPU) {}
+ HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {}
- ~BroadcasterTest() override {
+ ~HierarchicalTreeBroadcasterTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
void Init(int num_workers, int num_devices_per_worker, DataType dtype,
const DeviceType& device_type, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
VLOG(2) << "num_workers=" << num_workers
<< " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker;
@@ -400,8 +401,6 @@ class BroadcasterTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
- void BuildColParams() {}
-
template <typename T>
void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
int num_devices, int tensor_len, int fail_after,
@@ -511,10 +510,47 @@ class BroadcasterTest : public ::testing::Test {
}
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank,
+ const std::vector<int>& expected_subdiv_source_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ cp->instance.impl_details.subdiv_source_rank.clear();
+ // Create a stub broadcaster only for testing param initialization.
+ HierarchicalTreeBroadcaster broadcaster;
+ TF_CHECK_OK(broadcaster.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ EXPECT_EQ(expected_subdiv_source_rank,
+ cp->instance.impl_details.subdiv_source_rank);
+ }
+
+ void PrepColParamsForSubdivPermsTest(CollectiveParams* cp, int num_tasks,
+ int num_gpus) {
+ cp->group.device_type = DeviceType("GPU");
+ cp->group.num_tasks = num_tasks;
+ cp->group.group_size = num_tasks * num_gpus;
+ cp->instance.type = BROADCAST_COLLECTIVE;
+ cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ for (int ti = 0; ti < num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < num_gpus; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp->instance.task_names.push_back(task_name);
+ cp->instance.device_names.push_back(dev_name);
+ }
+ }
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
- const DeviceType& device_type, BroadcasterTest* parent)
+ const DeviceType& device_type,
+ HierarchicalTreeBroadcasterTest* parent)
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
@@ -636,21 +672,20 @@ class BroadcasterTest : public ::testing::Test {
ctx.allocate_output(0, tensor_.shape(), &output_tensor_ptr));
}
CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+ const Tensor* input_tensor_ptr =
+ col_params_.is_source ? &tensor_ : nullptr;
// Prepare a Broadcaster instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- Broadcaster broadcaster(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId,
- output_tensor_ptr);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- broadcaster.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- notification.WaitForNotification();
+ HierarchicalTreeBroadcaster broadcaster;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, input_tensor_ptr, output_tensor_ptr);
+ TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+
+ // Run the broadcast.
+ broadcaster.Run([this](Status s) { status_ = s; });
if (status_.ok()) {
CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
}
@@ -658,15 +693,13 @@ class BroadcasterTest : public ::testing::Test {
dev_ctx->Unref();
}
- BroadcasterTest* parent_;
+ HierarchicalTreeBroadcasterTest* parent_;
string dev_name_;
DeviceType device_type_ = DEVICE_CPU;
int rank_;
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
- std::unique_ptr<CollectiveAdapter> ca_;
- std::unique_ptr<OpKernelContext> ctx_;
Status status_;
}; // class DeviceInstance
@@ -688,6 +721,118 @@ class BroadcasterTest : public ::testing::Test {
int failure_count_ GUARDED_BY(mu_) = 0;
};
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 1, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
+
+ // source 2 device 2
+ cp.source_rank = 2;
+ cp.default_rank = 2;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 4, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 9
+ cp.source_rank = 9;
+ cp.default_rank = 9;
+ RunSubdivPermsTest(&cp,
+ {{0, 9, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
+ CollectiveParams cp;
+ int num_tasks = 4;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = num_tasks;
+ cp.group.group_size = 0;
+ cp.instance.type = BROADCAST_COLLECTIVE;
+ cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ std::vector<int> dev_per_task = {4, 4, 6, 8};
+ for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(dev_name);
+ cp.group.group_size++;
+ }
+ }
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 5
+ cp.source_rank = 9;
+ cp.default_rank = 5;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 9, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
+}
+
+// TODO(b/113171733): change to use TEST_P.
// Tests of full broadcast algorithm, with different device and
// data types.
// B = data element type
@@ -697,7 +842,7 @@ class BroadcasterTest : public ::testing::Test {
// L = tensor length
// A = abort after count
#define DEF_TEST(B, T, W, D, L, A, F) \
- TEST_F(BroadcasterTest, \
+ TEST_F(HierarchicalTreeBroadcasterTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A##_Fw##F) { \
DataType dtype = DT_##B; \
switch (dtype) { \
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index e26761703b..bb8eeb141a 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -14,13 +14,29 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include <stdlib.h>
+#include <atomic>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
@@ -36,7 +52,8 @@ string RingReduceBufKey(const string& exec_key, int pass, int section,
return strings::StrCat("rred(", exec_key, "):pass(", pass, "):section(",
section, "):srcrank(", source_rank, ")");
} else {
- // TODO(tucker): Try out some kind of denser encoding, e.g. 128 bit hash.
+ // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
+ // hash.
return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
}
}
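As a worked example (illustrative only, not part of the patch itself): with exec_key "ek", pass 1, section 2 and source rank 3, the readable form of the key is "rred(ek):pass(1):section(2):srcrank(3)", while the default compact form is "ek:1:2:3".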
@@ -65,105 +82,149 @@ RingReducer::RingField* RingReducer::PCQueue::Dequeue() {
return rf;
}
-RingReducer::RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx,
- OpKernelContext::Params* op_params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- op_params_(op_params),
- col_params_(col_params),
- exec_key_(exec_key),
- input_(input),
- output_(output),
- rank_(col_params.subdiv_rank[0]),
- step_id_(step_id),
- group_size_(col_params.group.group_size),
- num_subdivs_(static_cast<int>(
- col_params.instance.impl_details.subdiv_permutations.size())),
+RingReducer::RingReducer()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
done_(nullptr),
- device_(nullptr),
- device_name_(
- col_params_.instance.device_names[col_params_.default_rank]) {
- CHECK_GT(group_size_, 0);
- CHECK_GT(num_subdivs_, 0);
-}
+ group_size_(-1),
+ num_subdivs_(-1) {}
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
-string RingReducer::TensorDebugString(Tensor tensor) {
- const DeviceBase::GpuDeviceInfo* gpu_device_info =
- ctx_->device()->tensorflow_gpu_device_info();
- if (gpu_device_info) {
- Tensor cpu_tensor(tensor.dtype(), tensor.shape());
- Notification note;
- gpu_device_info->default_context->CopyDeviceTensorToCPU(
- &tensor, "" /*tensor_name*/, device_, &cpu_tensor,
- [&note](const Status& s) {
- CHECK(s.ok());
- note.Notify();
- });
- note.WaitForNotification();
- return cpu_tensor.SummarizeValue(64);
- } else {
- return tensor.SummarizeValue(64);
+Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+  // Each subdiv permutation is a ring formed by rotating each
+  // single-task subsequence of devices by an offset. This makes the most
+  // sense when each task has the same number of devices, but we can't
+  // depend on that being the case, so we compute something that
+  // works in any case.
+
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ // Generate a ring permutation for each requested offset.
+ if (col_params->instance.impl_details.subdiv_offsets.empty()) {
+ return errors::Internal(
+ "Subdiv offsets should be non-empty for ring reducer, size=",
+ col_params->instance.impl_details.subdiv_offsets.size());
+ }
+ VLOG(2) << "Setting up perms for col_params " << col_params
+ << " subdiv_permutations "
+ << &col_params->instance.impl_details.subdiv_permutations;
+ col_params->instance.impl_details.subdiv_permutations.resize(
+ col_params->instance.impl_details.subdiv_offsets.size());
+ col_params->subdiv_rank.resize(
+ col_params->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0;
+ sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
+ // A negative subdivision offset is interpreted as follows:
+ // 1. Reverse the local device ordering.
+ // 2. Begin the subdivision at abs(offset) in the reversed ordering.
+ bool reverse = false;
+ if (offset < 0) {
+ offset = abs(offset);
+ reverse = true;
+ }
+ int prior_dev_count = 0; // sum over prior worker device counts
+ for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int di_offset = (di + offset) % dev_per_task[ti];
+ int offset_di =
+ reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
+ // Device index in global subdivision permutation.
+ int permuted_di = prior_dev_count + offset_di;
+ int rank = static_cast<int>(perm.size());
+ perm.push_back(permuted_di);
+ if (col_params->instance.device_names[permuted_di] == device_name) {
+ CHECK_EQ(permuted_di, col_params->default_rank);
+ col_params->subdiv_rank[sdi] = rank;
+ }
+ }
+ prior_dev_count += dev_per_task[ti];
+ }
+ CHECK_EQ(col_params->group.group_size, perm.size());
}
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
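For illustration (not part of the patch itself), a minimal standalone sketch of the per-task rotation performed above, assuming a single task of 8 devices; a negative offset reverses the local device order before rotating, and the offsets 3 and -3 reproduce the first-task permutations expected by the RingReducerTest InitializeParams case later in this change.

#include <cstdio>
#include <cstdlib>
#include <vector>

std::vector<int> RotateTask(int num_dev, int offset, int prior_dev_count) {
  bool reverse = offset < 0;
  offset = abs(offset);
  std::vector<int> perm;
  for (int di = 0; di < num_dev; ++di) {
    int di_offset = (di + offset) % num_dev;
    int offset_di = reverse ? num_dev - (di_offset + 1) : di_offset;
    // Global device index = offset within task + devices in prior tasks.
    perm.push_back(prior_dev_count + offset_di);
  }
  return perm;
}

int main() {
  for (int d : RotateTask(8, 3, 0)) printf("%d,", d);   // 3,4,5,6,7,0,1,2,
  printf("\n");
  for (int d : RotateTask(8, -3, 0)) printf("%d,", d);  // 4,3,2,1,0,7,6,5,
  printf("\n");
  return 0;
}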
+
+Status RingReducer::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
}
void RingReducer::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
done_ = std::move(done);
+ group_size_ = col_params_->group.group_size;
+ num_subdivs_ = static_cast<int>(
+ col_params_->instance.impl_details.subdiv_permutations.size());
+ CHECK_GT(num_subdivs_, 0);
- // Get local execution device.
if (VLOG_IS_ON(1)) {
string buf;
- for (int r = 0; r < col_params_.instance.device_names.size(); ++r) {
+ for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
strings::StrAppend(&buf, "dev ", r, " : ",
- col_params_.instance.device_names[r], "\n");
+ col_params_->instance.device_names[r], "\n");
}
for (int sd = 0;
- sd < col_params_.instance.impl_details.subdiv_permutations.size();
+ sd < col_params_->instance.impl_details.subdiv_permutations.size();
++sd) {
strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
- for (auto x : col_params_.instance.impl_details.subdiv_permutations[sd]) {
+ for (auto x :
+ col_params_->instance.impl_details.subdiv_permutations[sd]) {
strings::StrAppend(&buf, x, ", ");
}
}
- VLOG(1) << "RingReducer::Run for device " << device_name_
- << " default_rank " << col_params_.default_rank << "\n"
+ VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
+ << " default_rank " << col_params_->default_rank << "\n"
<< buf;
}
- CHECK(dev_mgr_);
- Status status = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status.ok()) {
- LOG(ERROR) << "Failed to find device "
- << col_params_.instance.device_names[col_params_.default_rank];
- for (auto d : dev_mgr_->ListDevices()) {
- LOG(ERROR) << "Available device " << d->name();
- }
- done_(status);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- VLOG(1) << this << " default_rank " << col_params_.default_rank << " cp "
- << &col_params_ << ": " << col_params_.ToString();
// Start by copying input to output if they're not already the same, i.e. if
// we're not computing in-place on the input tensor.
- if ((input_ != output_) &&
- (DMAHelper::base(input_) != DMAHelper::base(output_))) {
+ if ((col_ctx_->input != col_ctx_->output) &&
+ (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
// We are running in a blockable thread and the callback can't block so
// just wait here on the copy.
Notification note;
+ Status status;
CollectiveRemoteAccessLocal::MemCpyAsync(
- ctx_->input_device_context(0), ctx_->op_device_context(), device_,
- device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
- output_, 0 /*dev_to_dev_stream_index*/,
+ col_ctx_->op_ctx->input_device_context(0),
+ col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
+ col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
[this, &note, &status](const Status& s) {
status.Update(s);
note.Notify();
@@ -177,24 +238,43 @@ void RingReducer::Run(StatusCallback done) {
ContinueAfterInputCopy();
}
+string RingReducer::TensorDebugString(const Tensor& tensor) {
+ const DeviceBase::GpuDeviceInfo* gpu_device_info =
+ col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+ if (gpu_device_info) {
+ Tensor cpu_tensor(tensor.dtype(), tensor.shape());
+ Notification note;
+ gpu_device_info->default_context->CopyDeviceTensorToCPU(
+ &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
+ [&note](const Status& s) {
+ CHECK(s.ok());
+ note.Notify();
+ });
+ note.WaitForNotification();
+ return cpu_tensor.SummarizeValue(64);
+ } else {
+ return tensor.SummarizeValue(64);
+ }
+}
+
// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
- AllocatorAttributes attr = ctx_->output_alloc_attr(0);
- ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_,
- device_->GetAllocator(attr)));
+ AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
+ ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
+ col_ctx_->device->GetAllocator(attr)));
- if (col_params_.final_op) {
+ if (col_params_->final_op) {
// Create an on-device scalar value from group_size_ that may be needed
// later.
// TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
// can be provided to the kernel in host memory?
Tensor group_size_val = ca_->Scalar(group_size_);
- if (col_params_.group.device_type != "CPU") {
- group_size_tensor_ =
- ca_->Scalar(device_->GetAllocator(ctx_->input_alloc_attr(0)));
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, device_,
+ if (col_params_->group.device_type != "CPU") {
+ group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
+ col_ctx_->op_ctx->input_alloc_attr(0)));
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
&group_size_tensor_,
[this](const Status& s) {
if (!s.ok()) {
@@ -231,14 +311,14 @@ void RingReducer::StartAbort(const Status& s) {
// cancellation on all of the outstanding CollectiveRemoteAccess
// actions.
if (abort_started) {
- col_exec_->StartAbort(s);
+ col_ctx_->col_exec->StartAbort(s);
}
}
void RingReducer::Finish(bool ok) {
if (ok) {
// Recover the output from the adaptor.
- ca_->ConsumeFinalValue(output_);
+ ca_->ConsumeFinalValue(col_ctx_->output);
}
Status s;
{
@@ -275,7 +355,7 @@ Status RingReducer::ComputeBinOp(Device* device, OpKernel* op, Tensor* output,
// TODO(tucker): Is it possible to cache and reuse these objects? They're
// mostly identical inside one device execution.
std::unique_ptr<SubContext> sub_ctx(
- new SubContext(ctx_, op_params_, op, output, input));
+ new SubContext(col_ctx_->op_ctx, col_ctx_->op_params, op, output, input));
device->Compute(op, sub_ctx->sub_ctx_);
return sub_ctx->sub_ctx_->status();
}
@@ -295,18 +375,18 @@ void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
rf->chunk_idx = chunk_idx;
rf->subdiv_idx = subdiv_idx;
rf->sc_idx = field_idx;
- rf->rank = col_params_.subdiv_rank[subdiv_idx];
+ rf->rank = col_params_->subdiv_rank[subdiv_idx];
rf->second_pass = false;
rf->action = RF_INIT;
// Recv from the device with preceding rank within the subdivision.
int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
int send_to_rank = (rf->rank + 1) % group_size_;
- rf->recv_dev_idx = col_params_.instance.impl_details
+ rf->recv_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][recv_from_rank];
- int send_dev_idx = col_params_.instance.impl_details
+ int send_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][send_to_rank];
- rf->recv_is_remote = !col_params_.task.is_local[rf->recv_dev_idx];
- rf->send_is_remote = !col_params_.task.is_local[send_dev_idx];
+ rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
+ rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
if (ca_->ChunkBytes(rf->sc_idx) > 0) {
// In pass 0 we skip Recv when rank = chunk_idx
rf->do_recv = (rf->chunk_idx != rf->rank);
@@ -360,45 +440,47 @@ string RingReducer::RingField::DebugString() const {
void RingReducer::DispatchSend(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_send);
- string send_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, rf->rank);
- VLOG(3) << "DispatchSend rank=" << col_params_.default_rank << " send key "
+ string send_buf_key = RingReduceBufKey(col_ctx_->exec_key, rf->second_pass,
+ rf->sc_idx, rf->rank);
+ VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
<< send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
<< rf->sc_idx;
int send_to_rank = (rf->rank + 1) % group_size_;
- int send_to_dev_idx = col_params_.instance.impl_details
+ int send_to_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[rf->subdiv_idx][send_to_rank];
- col_exec_->PostToPeer(col_params_.instance.device_names[send_to_dev_idx],
- col_params_.instance.task_names[send_to_dev_idx],
- send_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), &rf->chunk,
- device_locality_, done);
+ col_ctx_->col_exec->PostToPeer(
+ col_params_->instance.device_names[send_to_dev_idx],
+ col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
+ col_ctx_->device_locality, done);
}
void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_recv);
string recv_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx,
+ RingReduceBufKey(col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
(rf->rank + (group_size_ - 1)) % group_size_);
- VLOG(3) << "DispatchRecv rank=" << col_params_.default_rank << " recv key "
+ VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
<< recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
- << ((col_params_.merge_op != nullptr) ? "tmp_chunk" : "chunk");
- Tensor* dst_tensor = (!rf->second_pass && (col_params_.merge_op != nullptr))
+ << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
+ Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
? &rf->tmp_chunk
: &rf->chunk;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[rf->recv_dev_idx],
- col_params_.instance.task_names[rf->recv_dev_idx],
- col_params_.task.is_local[rf->recv_dev_idx],
- recv_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, rf->subdiv_idx, done);
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[rf->recv_dev_idx],
+ col_params_->instance.task_names[rf->recv_dev_idx],
+ col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, rf->subdiv_idx, done);
}
string RingReducer::FieldState() {
- string s = strings::StrCat("RingReducer ",
- strings::Hex(reinterpret_cast<uint64>(this)),
- " exec ", exec_key_, " step_id=", step_id_,
- " state of all ", rfv_.size(), " fields:");
+ string s = strings::StrCat(
+ "RingReducer ", strings::Hex(reinterpret_cast<uint64>(this)), " exec ",
+ col_ctx_->exec_key, " step_id=", col_ctx_->step_id, " state of all ",
+ rfv_.size(), " fields:");
for (int i = 0; i < rfv_.size(); ++i) {
s.append("\n");
s.append(rfv_[i].DebugString());
@@ -468,8 +550,9 @@ bool RingReducer::RunAsyncParts() {
--recv_pending_count;
if (!rf->second_pass) {
rf->action = RF_REDUCE;
- Status s = ComputeBinOp(device_, col_params_.merge_op.get(),
- &rf->chunk, &rf->tmp_chunk);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->merge_op.get(),
+ &rf->chunk, &rf->tmp_chunk);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -479,11 +562,12 @@ bool RingReducer::RunAsyncParts() {
}
break;
case RF_REDUCE:
- if (!rf->second_pass && col_params_.final_op.get() && rf->is_final) {
+ if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) {
rf->action = RF_FINALIZE;
group_size_tensor_ready_.WaitForNotification();
- Status s = ComputeBinOp(device_, col_params_.final_op.get(),
- &rf->chunk, &group_size_tensor_);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->final_op.get(),
+ &rf->chunk, &group_size_tensor_);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -552,9 +636,11 @@ bool RingReducer::RunAsyncParts() {
CHECK_EQ(send_pending_count, 0);
CHECK_EQ(recv_pending_count, 0);
- VLOG(2) << this << " rank=" << rank_ << " finish;"
+ VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
<< " final value " << TensorDebugString(ca_->Value());
return !aborted;
}
+REGISTER_COLLECTIVE(RingReduce, RingReducer);
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h
index 3e1988e787..0848e37b52 100644
--- a/tensorflow/core/common_runtime/ring_reducer.h
+++ b/tensorflow/core/common_runtime/ring_reducer.h
@@ -16,25 +16,35 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_
#include <deque>
+#include <memory>
+#include <string>
+#include <vector>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class DeviceMgr;
+class Device;
// Ring-algorithm implementation of collective all-reduce.
-class RingReducer {
+class RingReducer : public CollectiveImplementationInterface {
public:
- RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* op_params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, const Tensor* input, Tensor* output);
+ RingReducer();
+ ~RingReducer() override;
- virtual ~RingReducer();
+ // Establishes the requested number of subdivision permutations based on the
+ // ring order implicit in the device order.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the ring reduce algorithm.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
private:
// Called when a bad status is received that implies we should terminate
@@ -101,7 +111,7 @@ class RingReducer {
// For constructing log messages for debugging.
string FieldState();
- string TensorDebugString(Tensor tensor);
+ string TensorDebugString(const Tensor& tensor);
// Producer/Consumer Queue of RingField structs.
class PCQueue {
@@ -116,30 +126,19 @@ class RingReducer {
std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_);
};
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- OpKernelContext::Params* op_params_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const Tensor* input_; // Not owned
- Tensor* output_; // Not owned
- const int rank_;
- const int64 step_id_;
- const int group_size_;
- const int num_subdivs_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
+ StatusCallback done_;
+ int group_size_;
+ int num_subdivs_;
Tensor group_size_tensor_;
Notification group_size_tensor_ready_;
std::unique_ptr<CollectiveAdapter> ca_;
- StatusCallback done_;
- Device* device_; // The device for which this instance labors
- const string device_name_;
- DeviceLocality device_locality_;
-
mutex status_mu_;
Status status_ GUARDED_BY(status_mu_);
-
std::vector<RingField> rfv_;
+
+ friend class RingReducerTest;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index fcdf9deff8..5e079dbce6 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-namespace {
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
@@ -135,27 +134,28 @@ class RingReducerTest : public ::testing::Test {
protected:
RingReducerTest() : device_type_(DEVICE_CPU) {}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
~RingReducerTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
void Init(int num_workers, int num_devices, DataType dtype,
const DeviceType& device_type, int num_subdivs, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
device_type_ = device_type;
std::vector<Device*> local_devices;
SessionOptions sess_opts;
@@ -201,6 +201,7 @@ class RingReducerTest : public ::testing::Test {
col_params_.instance.instance_key = kInstanceKey;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = REDUCTION_COLLECTIVE;
+ col_params_.instance.impl_details.collective_name = "RingReduce";
col_params_.instance.data_type = dtype;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
@@ -373,6 +374,22 @@ class RingReducerTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ // Create a stub ring reducer only for testing param initialization.
+ RingReducer reducer;
+ TF_CHECK_OK(reducer.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ reducer.group_size_tensor_ready_.Notify(); // To unblock destructor.
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
@@ -475,8 +492,8 @@ class RingReducerTest : public ::testing::Test {
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
- // We never actually execute the kernel, so we need to do the
- // output allocation that it would do, ourselves.
+ // We never actually execute the kernel, so we need to do the output
+ // allocation it would do, ourselves.
Tensor* output_tensor_ptr = nullptr;
TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
&output_tensor_ptr));
@@ -485,20 +502,17 @@ class RingReducerTest : public ::testing::Test {
// Prepare a RingReducer instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- RingReducer rr(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId, &tensor_,
- &tensor_);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- SchedClosure([this, &notification, &rr]() {
- rr.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- });
- notification.WaitForNotification();
- CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ RingReducer reducer;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, &tensor_, &tensor_);
+ TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+
+ // Run the all-reduce.
+ reducer.Run([this](Status s) { status_ = s; });
+ if (status_.ok()) {
+ CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ }
dev_ctx->Unref();
}
@@ -531,6 +545,57 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
+TEST_F(RingReducerTest, InitializeParams) {
+ static const int kNumDevsPerTask = 8;
+ static const int kNumTasks = 3;
+ static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp;
+ std::vector<string> device_names;
+ std::vector<string> task_names;
+ cp.group.group_key = 1;
+ cp.group.group_size = kNumDevs;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = kNumTasks;
+ cp.instance.instance_key = 3;
+ cp.instance.type = REDUCTION_COLLECTIVE;
+ cp.instance.data_type = DataType(DT_FLOAT);
+ cp.instance.shape = TensorShape({5});
+ cp.instance.impl_details.collective_name = "RingReduce";
+ cp.instance.impl_details.subdiv_offsets.push_back(0);
+ cp.is_source = false;
+ for (int i = 0; i < kNumDevs; ++i) {
+ int task_id = i / kNumDevsPerTask;
+ int dev_id = i % kNumDevsPerTask;
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
+ task_names.push_back(task_name);
+ string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
+ device_names.push_back(device_name);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(device_name);
+ }
+
+ int test_rank = 0;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, 4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
+ 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
+ {0, 4});
+
+ test_rank = 3;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {3, -3};
+ RunSubdivPermsTest(&cp,
+ {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
+ 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
+ {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
+ 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
+ {0, 1});
+}
+
+// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
@@ -604,5 +669,4 @@ DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
#endif
-} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index d4ac50cbbe..4cb277d5a8 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -21,6 +21,31 @@ limitations under the License.
namespace tensorflow {
+namespace {
+// A RegistrationInfo object stores the registration details of a collective
+// implementation. `factory` is used to create instances of the collective
+// implementation.
+struct RegistrationInfo {
+ // This constructor also creates, and stores in `param_resolver_instance`,
+ // what is effectively a static instance of the collective implementation.
+ // During param resolution of collective ops we return this static instance.
+ // The actual op execution gets a fresh instance using `factory`.
+ RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
+ : name(n),
+ factory(std::move(f)),
+ param_resolver_instance(this->factory()) {}
+ string name;
+ CollectiveRegistry::Factory factory;
+ CollectiveImplementationInterface* param_resolver_instance;
+};
+
+std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
+ static std::vector<RegistrationInfo>* registry =
+ new std::vector<RegistrationInfo>;
+ return registry;
+}
+} // namespace
+
string CollGroupParams::ToString() const {
return strings::StrCat("CollGroupParams {group_key=", group_key,
" group_size=", group_size,
@@ -102,7 +127,8 @@ string CollectiveParams::ToString() const {
strings::StrAppend(&v, " ", instance.ToString());
strings::StrAppend(&v, " ", task.ToString());
strings::StrAppend(&v, " default_rank=", default_rank,
- " is_source=", is_source, " subdiv_rank={");
+ " is_source=", is_source, " source_rank=", source_rank,
+ " subdiv_rank={");
for (const auto& r : subdiv_rank) {
strings::StrAppend(&v, r, ",");
}
@@ -115,7 +141,81 @@ string CollectiveParams::ToString() const {
return ctx->params_;
}
+CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
+ const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx,
+ OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params,
+ const string& exec_key, int64 step_id,
+ const Tensor* input, Tensor* output)
+ : col_exec(col_exec),
+ dev_mgr(dev_mgr),
+ op_ctx(ctx),
+ op_params(op_params),
+ col_params(col_params),
+ exec_key(exec_key),
+ step_id(step_id),
+ input(input),
+ output(output),
+ device(nullptr),
+ device_name(col_params.instance.device_names[col_params.default_rank]) {}
+
/*static*/
int64 CollectiveExecutor::kInvalidId = -1;
+/*static*/
+Status CollectiveRegistry::Lookup(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, false);
+}
+
+/*static*/
+Status CollectiveRegistry::LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, true);
+}
+
+/*static*/
+void CollectiveRegistry::GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry)
+ implementations->emplace_back(reg_info.factory());
+}
+
+/*static*/
+Status CollectiveRegistry::Register(const string& collective_name,
+ Factory factory) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name)
+ return errors::Internal("Already registered collective ",
+ collective_name);
+ }
+ registry->emplace_back(collective_name, std::move(factory));
+ return Status::OK();
+}
+
+/*static*/
+Status CollectiveRegistry::LookupHelper(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation, bool param_resolver) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name) {
+ if (param_resolver) {
+ *implementation = reg_info.param_resolver_instance;
+ } else {
+ *implementation = reg_info.factory();
+ }
+ return Status::OK();
+ }
+ }
+ return errors::Internal(
+ "CollectiveRegistry::Lookup did not find collective implementation ",
+ collective_name);
+}
+
} // namespace tensorflow
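For orientation, here is a small sketch of how a caller might exercise the two lookup paths added above: the shared param-resolver instance during parameter resolution, and a fresh factory-created instance for actual execution. The function name ResolveThenRun and the error handling are illustrative assumptions, not code from this patch.

#include <memory>

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status ResolveThenRun(const string& collective_name) {
  // Param resolution path: returns the registry's effectively-static
  // instance, which should only be used for InitializeCollectiveParams().
  CollectiveImplementationInterface* resolver_instance = nullptr;
  TF_RETURN_IF_ERROR(CollectiveRegistry::LookupParamResolverInstance(
      collective_name, &resolver_instance));

  // Execution path: the factory creates a fresh instance per lookup.
  CollectiveImplementationInterface* col_impl = nullptr;
  TF_RETURN_IF_ERROR(CollectiveRegistry::Lookup(collective_name, &col_impl));
  std::unique_ptr<CollectiveImplementationInterface> owned_impl(col_impl);

  // InitializeCollectiveContext() and Run() would follow on owned_impl.
  return Status::OK();
}

}  // namespace tensorflow

Note that only the execution-path instance is owned by the caller here; the param-resolver instance stays with the registry's RegistrationInfo.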
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 0b37b3a88c..e35edb09d0 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -30,7 +31,8 @@ class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
-class DeviceLocality;
+class Device;
+class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class Op;
@@ -64,10 +66,10 @@ struct CollGroupParams {
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
+ string collective_name;
std::vector<std::vector<int>> subdiv_permutations;
std::vector<int> subdiv_offsets;
- // broadcast only: rank of source in each subdiv
- std::vector<int> subdiv_source_rank;
+ std::vector<int> subdiv_source_rank; // rank of source in each subdiv
};
// Data common to all members of a collective instance.
@@ -104,6 +106,7 @@ struct CollectiveParams {
string name = ""; // node name used only for log or error messages
int default_rank = -1; // index of this op within device_names
bool is_source = false; // broadcast only
+ int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
@@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
virtual void StartAbort(const Status& s) = 0;
};
+class CollectiveContext {
+ public:
+ CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx, OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params, const string& exec_key,
+ int64 step_id, const Tensor* input, Tensor* output);
+
+ virtual ~CollectiveContext() = default;
+
+ CollectiveExecutor* col_exec; // Not owned
+ const DeviceMgr* dev_mgr; // Not owned
+ OpKernelContext* op_ctx; // Not owned
+ OpKernelContext::Params* op_params; // Not owned
+ const CollectiveParams& col_params;
+ const string exec_key;
+ const int64 step_id;
+ const Tensor* input; // Not owned
+ Tensor* output; // Not owned
+ Device* device; // The device for which this instance labors
+ const string device_name;
+ DeviceLocality device_locality;
+};
+
+// Interface of a Collective Op implementation. Each specific CollectiveOp will
+// implement this interface and register the implementation via the
+// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
+// common_runtime/hierarchical_tree_broadcaster for examples.
+class CollectiveImplementationInterface {
+ public:
+ virtual ~CollectiveImplementationInterface() = default;
+
+ // Initializes the portions of `col_params` specific to this
+ // implementation. Called exactly once for every Collective instance during
+ // the CollectiveParams resolution process when the graph is first executed.
+  // NOTE(ayushd): This is effectively a static function because it modifies
+  // `col_params` in place and should not touch any data members. However,
+  // because it is virtual and must be implemented by every derived class, it
+  // is not marked static.
+ virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
+
+ // Prepares the CollectiveContext for executing this CollectiveImplementation.
+ // Called from CollectiveExecutor right before calling Run(). The
+ // CollectiveContext passed in must outlive the CollectiveImplementation
+ // object.
+ virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+
+ // Processes and moves data according to the logic of this Collective
+ // implementation. Relies on appropriate initialization of op-specific
+ // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
+ // context initialization in InitializeCollectiveContext().
+ virtual void Run(StatusCallback done) = 0;
+};
+
+// Static-methods only class for registering and looking up collective
+// implementations.
+class CollectiveRegistry {
+ public:
+ using Factory = std::function<CollectiveImplementationInterface*()>;
+  // Looks up a previously registered CollectiveImplementation under
+  // `collective_name`. If found, creates a new instance of the implementation
+  // and assigns it to `implementation`.
+ static Status Lookup(const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, returns the static instance of this
+ // implementation via `implementation`. This instance should only be used to
+  // call InitializeCollectiveParams.
+ static Status LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+  // Populates `implementations` with a new instance of each registered collective.
+ static void GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations);
+
+ private:
+ friend class CollectiveRegistration;
+ // Registers a CollectiveImplementation with name `collective_name` and
+ // factory `factory`. The latter is a function used to create instances of
+  // the CollectiveImplementation. Also creates a static instance of the
+  // implementation; this instance is used during param resolution and should
+  // only be used to call InitializeCollectiveParams.
+ static Status Register(const string& collective_name, Factory factory);
+
+ static Status LookupHelper(const string& collective_name,
+ CollectiveImplementationInterface** implementation,
+ bool param_resolver);
+};
+
+// Class used to call CollectiveRegistry::Register. This should only be used to
+// create a global static object.
+class CollectiveRegistration {
+ public:
+ CollectiveRegistration(const string& collective_name,
+ CollectiveRegistry::Factory factory) {
+ TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
+ }
+};
+
+#define REGISTER_COLLECTIVE(name, implementation) \
+ static CollectiveRegistration register_##name##_collective( \
+ #name, []() { return new implementation; });
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
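To make the registration flow concrete, here is a minimal sketch of an implementation plugging into the new interface and macro. The class name NoOpCollective and its trivial behavior are illustrative assumptions; the real implementations live in common_runtime/ring_reducer and common_runtime/hierarchical_tree_broadcaster, as the header comment notes.

#include "tensorflow/core/framework/collective.h"

namespace tensorflow {
namespace {

class NoOpCollective : public CollectiveImplementationInterface {
 public:
  // Runs on the registry's static param-resolver instance; only touches the
  // CollectiveParams passed in.
  Status InitializeCollectiveParams(CollectiveParams* col_params) override {
    col_params->instance.impl_details.collective_name = "NoOpCollective";
    return Status::OK();
  }

  // Runs on the fresh per-execution instance, right before Run().
  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
    col_ctx_ = col_ctx;
    return Status::OK();
  }

  // Moves no data; immediately reports success.
  void Run(StatusCallback done) override { done(Status::OK()); }

 private:
  CollectiveContext* col_ctx_ = nullptr;  // Not owned.
};

// Expands to a global CollectiveRegistration that registers both the factory
// and the static param-resolver instance under the name "NoOpCollective".
REGISTER_COLLECTIVE(NoOpCollective, NoOpCollective);

}  // namespace
}  // namespace tensorflow

At execution time the runtime would obtain a fresh NoOpCollective via CollectiveRegistry::Lookup, initialize its context, and call Run().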