diff options
author | 2018-08-27 14:19:20 -0700 | |
---|---|---|
committer | 2018-08-27 14:23:24 -0700 | |
commit | 85a6164912e21bc398b930943da7ea90ffe3bc20 (patch) | |
tree | af2efcf298518583c03dc2d7d415cd72df1d60b1 /tensorflow/core/common_runtime/collective_param_resolver_local.cc | |
parent | 59f3c57182fac4d745bb01f3976bb9832c06333d (diff) |
Refactor collectives to colocate implementation-specific code.
Before this change, introducing a new collective algorithm required touching
multiple files. CollectiveParams setup was in common_runtime/collective_param_resolver_local,
and the data movement was in common_runtime/reducer and common_runtime/broadcaster.
This change introduces CollectiveImplementationInterface.
CollectiveImplementationInterface brings together param initialization and data
movement for a collective algorithm. Every collective implementation will
implement this interface and override its virtual methods. This should
reduce obscurity and lead to code with fewer dependencies.
PiperOrigin-RevId: 210430157
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local.cc')
-rw-r--r-- | tensorflow/core/common_runtime/collective_param_resolver_local.cc | 237 |
1 file changed, 30 insertions, 207 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 2a14493a67..52eedae9b7 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -14,7 +14,20 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include <stddef.h> +#include <algorithm> +#include <unordered_map> +#include <utility> + #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -319,206 +332,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) { } } // namespace -int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) { - int num_tasks = static_cast<int>(dev_per_task.size()); - int task_lo = 0; - int task_hi; - for (int ti = 0; ti < num_tasks; ti++) { - task_hi = task_lo + dev_per_task[ti]; - if (task_lo <= device_rank && device_rank < task_hi) return ti; - task_lo += dev_per_task[ti]; - } - LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi - << " devices"; - return -1; -} - -void CollectiveParamResolverLocal::GenerateBcastSubdivPerms( - const string& device, int source_rank, const std::vector<int>& dev_per_task, - CollectiveParams* cp) { - if (VLOG_IS_ON(1)) { - string dpt_buf; - for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";"); - VLOG(1) << "GenerateBcastSubdivPerms device=" << device - << " 
source_rank=" << source_rank << " dev_per_task=" << dpt_buf; - } - int num_tasks = cp->group.num_tasks; - // If there is just 1 task, then execute binary tree broadcast over all - // devices. Otherwise, the first subdiv is inter-task broadcast, and then - // there are N more subdivs, where N is #task. - int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0); - int total_num_devices = 0; - for (int num_dev : dev_per_task) total_num_devices += num_dev; - - cp->instance.impl_details.subdiv_permutations.resize(num_subdivs); - cp->subdiv_rank.reserve(num_subdivs); - cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs); - - // Inter-task subdiv. Pick one device from each task - this is the source - // device if it belongs to that task, or device 0 for that task. If a device - // does not participate in the subdiv, set subdiv_rank to -1. - if (num_tasks > 1) { - const int sdi = 0; - std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - int device_count = 0; - int source_task = GetDeviceTask(source_rank, dev_per_task); - for (int ti = 0; ti < cp->group.num_tasks; ti++) { - bool participate = false; - if (source_task == ti) { - // Source device belongs to this task. - perm.push_back(source_rank); - participate = cp->instance.device_names[source_rank] == device; - } else { - // Source does not belong to this task, choose dev 0. - perm.push_back(device_count); - participate = cp->instance.device_names[device_count] == device; - } - if (participate) cp->subdiv_rank.push_back(ti); - device_count += dev_per_task[ti]; - } - if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1); - cp->instance.impl_details.subdiv_source_rank.push_back(source_task); - } - - // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set - // source to dev 0 for that task if it does not contain original source, else - // set to rank of original source. 
If a device does not participate in the - // subdiv, set subdiv_rank to -1; - int abs_di = 0; - for (int ti = 0; ti < cp->group.num_tasks; ti++) { - const int sdi = ti + (num_tasks > 1 ? 1 : 0); - std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - bool participate = false; - int subdiv_source = 0; - for (int di = 0; di < dev_per_task[ti]; di++) { - perm.push_back(abs_di); - if (cp->instance.device_names[abs_di] == device) { - participate = true; - cp->subdiv_rank.push_back(di); - } - if (abs_di == source_rank) subdiv_source = di; - abs_di++; - } - if (!participate) cp->subdiv_rank.push_back(-1); - cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source); - } - - for (int sri = 0; sri < num_subdivs; sri++) { - CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0); - } -} - -// Establish the requested number of subdivision permutations based on the -// ring order implicit in the device order. -/*static*/ -void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device, - int source_rank, - CollectiveParams* cp) { - // Each subdiv permutation is a ring formed by rotating each - // single-task subsequence of devices by an offset. This makes most - // sense when each task has the same number of devices but we can't - // depend on that being the case so we'll compute something that - // works in any case. - - // Start by counting the devices in each task. - // Precondition: device_names must be sorted so that all devices in - // the same task are adjacent. 
- VLOG(2) << "Sorted task names: " - << str_util::Join(cp->instance.task_names, ", "); - std::vector<int> dev_per_task; - const string* prior_task_name = &cp->instance.task_names[0]; - int dev_count = 1; - for (int di = 1; di < cp->group.group_size; ++di) { - if (cp->instance.task_names[di] != *prior_task_name) { - dev_per_task.push_back(dev_count); - dev_count = 1; - prior_task_name = &cp->instance.task_names[di]; - } else { - ++dev_count; - } - } - dev_per_task.push_back(dev_count); - CHECK_EQ(cp->group.num_tasks, dev_per_task.size()); - - CHECK(cp->instance.type == REDUCTION_COLLECTIVE || - cp->instance.type == BROADCAST_COLLECTIVE); - if (cp->instance.type == REDUCTION_COLLECTIVE) { - // Generate a ring permutation for each requested offset. - CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0); - VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations " - << &cp->instance.impl_details.subdiv_permutations; - cp->instance.impl_details.subdiv_permutations.resize( - cp->instance.impl_details.subdiv_offsets.size()); - cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1); - for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size(); - ++sdi) { - std::vector<int>& perm = - cp->instance.impl_details.subdiv_permutations[sdi]; - CHECK_EQ(perm.size(), 0); - int offset = cp->instance.impl_details.subdiv_offsets[sdi]; - // A negative subdivision offset is interpreted as follows: - // 1. Reverse the local device ordering. - // 2. Begin the subdivision at abs(offset) in the reversed ordering. - bool reverse = false; - if (offset < 0) { - offset = abs(offset); - reverse = true; - } - int prior_dev_count = 0; // sum over prior worker device counts - for (int ti = 0; ti < cp->group.num_tasks; ++ti) { - for (int di = 0; di < dev_per_task[ti]; ++di) { - int di_offset = (di + offset) % dev_per_task[ti]; - int offset_di = - reverse ? 
(dev_per_task[ti] - (di_offset + 1)) : di_offset; - // Device index in global subdivision permutation. - int permuted_di = prior_dev_count + offset_di; - int rank = static_cast<int>(perm.size()); - perm.push_back(permuted_di); - if (cp->instance.device_names[permuted_di] == device) { - CHECK_EQ(permuted_di, cp->default_rank); - cp->subdiv_rank[sdi] = rank; - } - } - prior_dev_count += dev_per_task[ti]; - } - CHECK_EQ(cp->group.group_size, perm.size()); - } - } else if (cp->instance.type == BROADCAST_COLLECTIVE) { - GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp); - } - - if (VLOG_IS_ON(1)) { - // Log the computed ring order for each subdiv. - string buf; - for (int sdi = 0; - sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) { - buf = strings::StrCat("Subdiv ", sdi, " device order:\n"); - for (int di = 0; - di < cp->instance.impl_details.subdiv_permutations[sdi].size(); - ++di) { - int idx = cp->instance.impl_details.subdiv_permutations[sdi][di]; - if (idx >= 0) { - CHECK_GT(cp->instance.device_names.size(), idx); - strings::StrAppend(&buf, cp->instance.device_names[idx], "\n"); - } - } - strings::StrAppend(&buf, " subdiv_offsets: "); - for (auto o : cp->instance.impl_details.subdiv_offsets) - strings::StrAppend(&buf, o, " "); - strings::StrAppend(&buf, " SubdivRank: "); - for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " "); - if (cp->instance.type == BROADCAST_COLLECTIVE) { - strings::StrAppend(&buf, " subdiv_source_rank: "); - for (auto src : cp->instance.impl_details.subdiv_source_rank) - strings::StrAppend(&buf, src, " "); - } - VLOG(1) << buf; - } - } -} - void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp) { cp->task.is_local.resize(cp->group.group_size, false); @@ -785,29 +598,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( // Populate the fields common across task, also default_rank. 
SetDefaultRank(device, cp); CompleteTaskIsLocal(task_name_, cp); + // TODO(b/113171733): we need a better way to pick the collective + // implementation. The ideal way would depend upon the topology and link + // strength before picking a particular implementation. + cp->instance.impl_details.collective_name = + (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast" + : "RingReduce"; + CollectiveImplementationInterface* col_impl; + Status lookup_status = CollectiveRegistry::LookupParamResolverInstance( + cp->instance.impl_details.collective_name, &col_impl); + if (!lookup_status.ok()) { + done(lookup_status); + return; + } // If broadcast, may need to wait for source discovery. if (cp->instance.type == BROADCAST_COLLECTIVE) { CompleteInstanceSource(ir, cp, is_source, - [this, ir, device, cp, done](InstanceRec* irec) { + [col_impl, ir, device, cp, done](InstanceRec* irec) { CHECK_EQ(ir, irec); Status s; - int source_rank; { mutex_lock l(irec->out_mu); irec->WaitForOutMu(l); s = irec->status; - source_rank = irec->source_rank; + cp->source_rank = irec->source_rank; } if (s.ok()) { - GenerateSubdivPerms(device, source_rank, cp); + s = col_impl->InitializeCollectiveParams(cp); } done(s); }); - return; } else { - GenerateSubdivPerms(device, 0, cp); + done(col_impl->InitializeCollectiveParams(cp)); } - done(Status::OK()); } void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir, |