diff options
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/collective_param_resolver_local_test.cc | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index d5be8f927e..9ea23b72d2 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -49,6 +49,47 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
     CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
   }
 
+  // Test helper: regenerates the broadcast subdivision data stored in `cp`
+  // for one device and compares it against the expected values.
+  //
+  // The subdiv_rank, subdiv_permutations and subdiv_source_rank already held
+  // by `cp` are cleared first; the tests below reuse one CollectiveParams
+  // across several calls.  GenerateBcastSubdivPerms is then invoked with the
+  // name stored at cp->instance.device_names[device_rank].
+  //
+  //   cp                           params to regenerate; modified in place.
+  //   dev_per_task                 forwarded unchanged to the generator.
+  //   device_rank                  index into cp->instance.device_names.
+  //   source_rank                  forwarded unchanged to the generator.
+  //   expected_subdiv_perms        expected impl_details.subdiv_permutations.
+  //   expected_subdiv_rank         expected cp->subdiv_rank.
+  //   expected_subdiv_source_rank  expected impl_details.subdiv_source_rank.
+  void BcastSubdivPerms(
+      CollectiveParams* cp, const std::vector<int>& dev_per_task,
+      int device_rank, int source_rank,
+      const std::vector<std::vector<int>>& expected_subdiv_perms,
+      const std::vector<int>& expected_subdiv_rank,
+      const std::vector<int>& expected_subdiv_source_rank) {
+    // Tag every failure below with the arguments of this call; without it the
+    // reported line is inside this helper and the calling case is ambiguous.
+    SCOPED_TRACE(strings::StrCat("device_rank=", device_rank,
+                                 " source_rank=", source_rank));
+    // Fail the test instead of indexing device_names out of range.
+    ASSERT_GE(device_rank, 0);
+    ASSERT_LT(device_rank,
+              static_cast<int>(cp->instance.device_names.size()));
+    cp->subdiv_rank.clear();
+    cp->instance.impl_details.subdiv_permutations.clear();
+    cp->instance.impl_details.subdiv_source_rank.clear();
+    CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
+        cp->instance.device_names[device_rank], source_rank, dev_per_task, cp);
+    EXPECT_EQ(expected_subdiv_perms,
+              cp->instance.impl_details.subdiv_permutations);
+    EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+    EXPECT_EQ(expected_subdiv_source_rank,
+              cp->instance.impl_details.subdiv_source_rank);
+  }
+
   std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
@@ -216,4 +257,123 @@ TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
   EXPECT_EQ(1, cp.subdiv_rank[1]);
 }
 
+// One task with 8 GPUs.  Each case expects a single subdivision that lists
+// all 8 devices in rank order.
+TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) {
+  CollectiveParams cp;
+  cp.group.device_type = DeviceType("GPU");
+  cp.group.num_tasks = 1;
+  cp.instance.type = BROADCAST_COLLECTIVE;
+  for (int i = 0; i < 8; i++) {
+    string dev_name =
+        strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i);
+    cp.instance.device_names.push_back(dev_name);
+  }
+  std::vector<int> dev_per_task = {8};
+
+  // Source rank 0, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
+                   {0});
+
+  // Source rank 2, checked from device rank 2.
+  BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2},
+                   {2});
+
+  // Source rank 2, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
+                   {2});
+}
+
+// Four tasks with 8 GPUs each.  Each case expects five subdivisions: the
+// first holds one device per task (the source device for the source's task,
+// the task's first device otherwise), followed by one per task listing all of
+// that task's devices.  An expected subdiv_rank of -1 marks a subdivision
+// whose permutation does not contain the checked device.
+TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) {
+  CollectiveParams cp;
+  cp.group.device_type = DeviceType("GPU");
+  cp.group.num_tasks = 4;
+  cp.instance.type = BROADCAST_COLLECTIVE;
+  for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+    for (int di = 0; di < 8; di++) {
+      string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
+                                        "/device:GPU:", di);
+      cp.instance.device_names.push_back(dev_name);
+    }
+  }
+  std::vector<int> dev_per_task = {8, 8, 8, 8};
+
+  // Source rank 0, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 0,
+                   {{0, 8, 16, 24},
+                    {0, 1, 2, 3, 4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13, 14, 15},
+                    {16, 17, 18, 19, 20, 21, 22, 23},
+                    {24, 25, 26, 27, 28, 29, 30, 31}},
+                   {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+  // Source rank 2, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 2,
+                   {{2, 8, 16, 24},
+                    {0, 1, 2, 3, 4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13, 14, 15},
+                    {16, 17, 18, 19, 20, 21, 22, 23},
+                    {24, 25, 26, 27, 28, 29, 30, 31}},
+                   {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+  // Source rank 9, checked from device rank 9.
+  BcastSubdivPerms(&cp, dev_per_task, 9, 9,
+                   {{0, 9, 16, 24},
+                    {0, 1, 2, 3, 4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13, 14, 15},
+                    {16, 17, 18, 19, 20, 21, 22, 23},
+                    {24, 25, 26, 27, 28, 29, 30, 31}},
+                   {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
+}
+
+// Same layout as above, but with 4, 4, 6 and 8 GPUs on the four tasks, so the
+// per-task subdivisions have different lengths.
+TEST_F(CollectiveParamResolverLocalTest,
+       GenerateBcastSubdivPerms4TasksVariableGPU) {
+  CollectiveParams cp;
+  cp.group.device_type = DeviceType("GPU");
+  cp.group.num_tasks = 4;
+  cp.instance.type = BROADCAST_COLLECTIVE;
+  std::vector<int> dev_per_task = {4, 4, 6, 8};
+  for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+    for (int di = 0; di < dev_per_task[ti]; di++) {
+      string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
+                                        "/device:GPU:", di);
+      cp.instance.device_names.push_back(dev_name);
+    }
+  }
+
+  // Source rank 0, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 0,
+                   {{0, 4, 8, 14},
+                    {0, 1, 2, 3},
+                    {4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13},
+                    {14, 15, 16, 17, 18, 19, 20, 21}},
+                   {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+  // Source rank 2, checked from device rank 0.
+  BcastSubdivPerms(&cp, dev_per_task, 0, 2,
+                   {{2, 4, 8, 14},
+                    {0, 1, 2, 3},
+                    {4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13},
+                    {14, 15, 16, 17, 18, 19, 20, 21}},
+                   {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+  // Source rank 9, checked from device rank 5.
+  BcastSubdivPerms(&cp, dev_per_task, 5, 9,
+                   {{0, 4, 9, 14},
+                    {0, 1, 2, 3},
+                    {4, 5, 6, 7},
+                    {8, 9, 10, 11, 12, 13},
+                    {14, 15, 16, 17, 18, 19, 20, 21}},
+                   {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
+}
+
 }  // namespace tensorflow