diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-26 22:11:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-26 22:15:29 -0700 |
commit | dd1f0cdddbfb3104e390026184474d4374f52642 (patch) | |
tree | 6bd1f58a74e4ba771f99e104ce874c0a98a51ffc /tensorflow | |
parent | 631a364cd1ddd822cf4b8712a5388d5ea39ecd7e (diff) |
Supports lookup devices by fullname either in the canonical form or the
legacy form. This makes DeviceSet behaves the same as DeviceMgr's
FindDevice method.
PiperOrigin-RevId: 163300346
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/common_runtime/device_mgr.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/device_set.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/util/device_name_utils.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/util/device_name_utils.h | 5 | ||||
-rw-r--r-- | tensorflow/core/util/device_name_utils_test.cc | 10 |
5 files changed, 34 insertions, 17 deletions
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 471463fc8b..0a4e0afc87 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -30,22 +30,9 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) devices_.push_back(d); // Register under the (1) full name, (2) canonical name, and (3) local name. - string full_name = d->name(); - device_map_[CopyToBackingStore(full_name)] = d; - - // TODO(b/62909072): Upgrade device_map_ to a better data structure. - DeviceNameUtils::ParsedName parsed_name = d->parsed_name(); - if (parsed_name.has_job && parsed_name.has_replica && - parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) { - string canonical_name = DeviceNameUtils::FullName( - parsed_name.job, parsed_name.replica, parsed_name.task, - parsed_name.type, parsed_name.id); - device_map_[CopyToBackingStore(canonical_name)] = d; - - string legacy_name = DeviceNameUtils::LegacyName( - parsed_name.job, parsed_name.replica, parsed_name.task, - parsed_name.type, parsed_name.id); - device_map_[CopyToBackingStore(legacy_name)] = d; + for (const string& name : + DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) { + device_map_[CopyToBackingStore(name)] = d; } string lname = DeviceNameUtils::LocalName(d->name()); device_map_[CopyToBackingStore(lname)] = d; diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index 493349176e..f6b4115cbf 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -32,7 +32,10 @@ DeviceSet::~DeviceSet() {} void DeviceSet::AddDevice(Device* device) { devices_.push_back(device); - device_by_name_.insert({device->name(), device}); + for (const string& name : + DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) { + device_by_name_.insert({name, device}); + } } void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 1773ac00b7..ea25f90a64 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -387,4 +387,16 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task, return false; } +std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings( + const ParsedName& pn) { + if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) { + return { + DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id), + DeviceNameUtils::LegacyName(pn.job, pn.replica, pn.task, pn.type, + pn.id)}; + } else { + return {}; + } +} + } // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h index 1f32828bae..740aa13fa7 100644 --- a/tensorflow/core/util/device_name_utils.h +++ b/tensorflow/core/util/device_name_utils.h @@ -152,6 +152,11 @@ class DeviceNameUtils { static bool SplitDeviceName(StringPiece name, string* task, string* device); static string ParsedNameToString(const ParsedName& pn); + + // Returns canonical and legacy full names for the given parsed + // device name 'pn'. The returned string names are often useful to + // lookup devices from a mapping. + static std::vector<string> GetNamesForDeviceMappings(const ParsedName& pn); }; } // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index e44b840967..008100aa44 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -443,6 +443,16 @@ TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) { MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*"); } +TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) { + DeviceNameUtils::ParsedName p = Name("/job:foo/replica:10/task:0/gpu:1"); + EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","), + "/job:foo/replica:10/task:0/device:GPU:1," + "/job:foo/replica:10/task:0/gpu:1"); + p.has_task = false; + EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","), + ""); +} + static void BM_ParseFullName(int iters) { DeviceNameUtils::ParsedName p; while (iters--) { |