aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/collective_param_resolver_local.cc
diff options
context:
space:
mode:
authorGravatar Ayush Dubey <ayushd@google.com>2018-08-27 14:19:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 14:23:24 -0700
commit85a6164912e21bc398b930943da7ea90ffe3bc20 (patch)
treeaf2efcf298518583c03dc2d7d415cd72df1d60b1 /tensorflow/core/common_runtime/collective_param_resolver_local.cc
parent59f3c57182fac4d745bb01f3976bb9832c06333d (diff)
Refactor collectives to colocate implementation-specific code.
Before this change, introducing a new collective algorithm required touching multiple files: CollectiveParams setup lived in common_runtime/collective_param_resolver_local, while the data movement lived in common_runtime/reducer and common_runtime/broadcaster. This change introduces CollectiveImplementationInterface, which brings together param initialization and data movement for a collective algorithm. Every collective implementation will implement this interface and override its virtual methods. This should hopefully reduce obscurity and lead to code with fewer dependencies. PiperOrigin-RevId: 210430157
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local.cc')
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc237
1 files changed, 30 insertions, 207 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 2a14493a67..52eedae9b7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -14,7 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include <stddef.h>
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -319,206 +332,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
}
} // namespace
-int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
- int num_tasks = static_cast<int>(dev_per_task.size());
- int task_lo = 0;
- int task_hi;
- for (int ti = 0; ti < num_tasks; ti++) {
- task_hi = task_lo + dev_per_task[ti];
- if (task_lo <= device_rank && device_rank < task_hi) return ti;
- task_lo += dev_per_task[ti];
- }
- LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
- << " devices";
- return -1;
-}
-
-void CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- const string& device, int source_rank, const std::vector<int>& dev_per_task,
- CollectiveParams* cp) {
- if (VLOG_IS_ON(1)) {
- string dpt_buf;
- for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
- VLOG(1) << "GenerateBcastSubdivPerms device=" << device
- << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf;
- }
- int num_tasks = cp->group.num_tasks;
- // If there is just 1 task, then execute binary tree broadcast over all
- // devices. Otherwise, the first subdiv is inter-task broadcast, and then
- // there are N more subdivs, where N is #task.
- int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
- int total_num_devices = 0;
- for (int num_dev : dev_per_task) total_num_devices += num_dev;
-
- cp->instance.impl_details.subdiv_permutations.resize(num_subdivs);
- cp->subdiv_rank.reserve(num_subdivs);
- cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
-
- // Inter-task subdiv. Pick one device from each task - this is the source
- // device if it belongs to that task, or device 0 for that task. If a device
- // does not participate in the subdiv, set subdiv_rank to -1.
- if (num_tasks > 1) {
- const int sdi = 0;
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int device_count = 0;
- int source_task = GetDeviceTask(source_rank, dev_per_task);
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- bool participate = false;
- if (source_task == ti) {
- // Source device belongs to this task.
- perm.push_back(source_rank);
- participate = cp->instance.device_names[source_rank] == device;
- } else {
- // Source does not belong to this task, choose dev 0.
- perm.push_back(device_count);
- participate = cp->instance.device_names[device_count] == device;
- }
- if (participate) cp->subdiv_rank.push_back(ti);
- device_count += dev_per_task[ti];
- }
- if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(source_task);
- }
-
- // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
- // source to dev 0 for that task if it does not contain original source, else
- // set to rank of original source. If a device does not participate in the
- // subdiv, set subdiv_rank to -1;
- int abs_di = 0;
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- const int sdi = ti + (num_tasks > 1 ? 1 : 0);
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- bool participate = false;
- int subdiv_source = 0;
- for (int di = 0; di < dev_per_task[ti]; di++) {
- perm.push_back(abs_di);
- if (cp->instance.device_names[abs_di] == device) {
- participate = true;
- cp->subdiv_rank.push_back(di);
- }
- if (abs_di == source_rank) subdiv_source = di;
- abs_di++;
- }
- if (!participate) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source);
- }
-
- for (int sri = 0; sri < num_subdivs; sri++) {
- CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0);
- }
-}
-
-// Establish the requested number of subdivision permutations based on the
-// ring order implicit in the device order.
-/*static*/
-void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
- int source_rank,
- CollectiveParams* cp) {
- // Each subdiv permutation is a ring formed by rotating each
- // single-task subsequence of devices by an offset. This makes most
- // sense when each task has the same number of devices but we can't
- // depend on that being the case so we'll compute something that
- // works in any case.
-
- // Start by counting the devices in each task.
- // Precondition: device_names must be sorted so that all devices in
- // the same task are adjacent.
- VLOG(2) << "Sorted task names: "
- << str_util::Join(cp->instance.task_names, ", ");
- std::vector<int> dev_per_task;
- const string* prior_task_name = &cp->instance.task_names[0];
- int dev_count = 1;
- for (int di = 1; di < cp->group.group_size; ++di) {
- if (cp->instance.task_names[di] != *prior_task_name) {
- dev_per_task.push_back(dev_count);
- dev_count = 1;
- prior_task_name = &cp->instance.task_names[di];
- } else {
- ++dev_count;
- }
- }
- dev_per_task.push_back(dev_count);
- CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
-
- CHECK(cp->instance.type == REDUCTION_COLLECTIVE ||
- cp->instance.type == BROADCAST_COLLECTIVE);
- if (cp->instance.type == REDUCTION_COLLECTIVE) {
- // Generate a ring permutation for each requested offset.
- CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
- VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
- << &cp->instance.impl_details.subdiv_permutations;
- cp->instance.impl_details.subdiv_permutations.resize(
- cp->instance.impl_details.subdiv_offsets.size());
- cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
- ++sdi) {
- std::vector<int>& perm =
- cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- // A negative subdivision offset is interpreted as follows:
- // 1. Reverse the local device ordering.
- // 2. Begin the subdivision at abs(offset) in the reversed ordering.
- bool reverse = false;
- if (offset < 0) {
- offset = abs(offset);
- reverse = true;
- }
- int prior_dev_count = 0; // sum over prior worker device counts
- for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
- for (int di = 0; di < dev_per_task[ti]; ++di) {
- int di_offset = (di + offset) % dev_per_task[ti];
- int offset_di =
- reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
- // Device index in global subdivision permutation.
- int permuted_di = prior_dev_count + offset_di;
- int rank = static_cast<int>(perm.size());
- perm.push_back(permuted_di);
- if (cp->instance.device_names[permuted_di] == device) {
- CHECK_EQ(permuted_di, cp->default_rank);
- cp->subdiv_rank[sdi] = rank;
- }
- }
- prior_dev_count += dev_per_task[ti];
- }
- CHECK_EQ(cp->group.group_size, perm.size());
- }
- } else if (cp->instance.type == BROADCAST_COLLECTIVE) {
- GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp);
- }
-
- if (VLOG_IS_ON(1)) {
- // Log the computed ring order for each subdiv.
- string buf;
- for (int sdi = 0;
- sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
- buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
- for (int di = 0;
- di < cp->instance.impl_details.subdiv_permutations[sdi].size();
- ++di) {
- int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
- if (idx >= 0) {
- CHECK_GT(cp->instance.device_names.size(), idx);
- strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
- }
- }
- strings::StrAppend(&buf, " subdiv_offsets: ");
- for (auto o : cp->instance.impl_details.subdiv_offsets)
- strings::StrAppend(&buf, o, " ");
- strings::StrAppend(&buf, " SubdivRank: ");
- for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
- if (cp->instance.type == BROADCAST_COLLECTIVE) {
- strings::StrAppend(&buf, " subdiv_source_rank: ");
- for (auto src : cp->instance.impl_details.subdiv_source_rank)
- strings::StrAppend(&buf, src, " ");
- }
- VLOG(1) << buf;
- }
- }
-}
-
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
@@ -785,29 +598,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// Populate the fields common across task, also default_rank.
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
+ // TODO(b/113171733): we need a better way to pick the collective
+ // implementation. The ideal way would depend upon the topology and link
+ // strength before picking a particular implementation.
+ cp->instance.impl_details.collective_name =
+ (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
+ : "RingReduce";
+ CollectiveImplementationInterface* col_impl;
+ Status lookup_status = CollectiveRegistry::LookupParamResolverInstance(
+ cp->instance.impl_details.collective_name, &col_impl);
+ if (!lookup_status.ok()) {
+ done(lookup_status);
+ return;
+ }
// If broadcast, may need to wait for source discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
CompleteInstanceSource(ir, cp, is_source,
- [this, ir, device, cp, done](InstanceRec* irec) {
+ [col_impl, ir, device, cp, done](InstanceRec* irec) {
CHECK_EQ(ir, irec);
Status s;
- int source_rank;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
- source_rank = irec->source_rank;
+ cp->source_rank = irec->source_rank;
}
if (s.ok()) {
- GenerateSubdivPerms(device, source_rank, cp);
+ s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
- return;
} else {
- GenerateSubdivPerms(device, 0, cp);
+ done(col_impl->InitializeCollectiveParams(cp));
}
- done(Status::OK());
}
void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,