aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-26 22:11:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 22:15:29 -0700
commitdd1f0cdddbfb3104e390026184474d4374f52642 (patch)
tree6bd1f58a74e4ba771f99e104ce874c0a98a51ffc /tensorflow
parent631a364cd1ddd822cf4b8712a5388d5ea39ecd7e (diff)
Supports lookup devices by fullname either in the canonical form or the
legacy form. This makes DeviceSet behaves the same as DeviceMgr's FindDevice method. PiperOrigin-RevId: 163300346
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/common_runtime/device_mgr.cc19
-rw-r--r--tensorflow/core/common_runtime/device_set.cc5
-rw-r--r--tensorflow/core/util/device_name_utils.cc12
-rw-r--r--tensorflow/core/util/device_name_utils.h5
-rw-r--r--tensorflow/core/util/device_name_utils_test.cc10
5 files changed, 34 insertions, 17 deletions
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index 471463fc8b..0a4e0afc87 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -30,22 +30,9 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
devices_.push_back(d);
// Register under the (1) full name, (2) canonical name, and (3) local name.
- string full_name = d->name();
- device_map_[CopyToBackingStore(full_name)] = d;
-
- // TODO(b/62909072): Upgrade device_map_ to a better data structure.
- DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
- if (parsed_name.has_job && parsed_name.has_replica &&
- parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
- string canonical_name = DeviceNameUtils::FullName(
- parsed_name.job, parsed_name.replica, parsed_name.task,
- parsed_name.type, parsed_name.id);
- device_map_[CopyToBackingStore(canonical_name)] = d;
-
- string legacy_name = DeviceNameUtils::LegacyName(
- parsed_name.job, parsed_name.replica, parsed_name.task,
- parsed_name.type, parsed_name.id);
- device_map_[CopyToBackingStore(legacy_name)] = d;
+ for (const string& name :
+ DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
+ device_map_[CopyToBackingStore(name)] = d;
}
string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d;
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc
index 493349176e..f6b4115cbf 100644
--- a/tensorflow/core/common_runtime/device_set.cc
+++ b/tensorflow/core/common_runtime/device_set.cc
@@ -32,7 +32,10 @@ DeviceSet::~DeviceSet() {}
void DeviceSet::AddDevice(Device* device) {
devices_.push_back(device);
- device_by_name_.insert({device->name(), device});
+ for (const string& name :
+ DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
+ device_by_name_.insert({name, device});
+ }
}
void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 1773ac00b7..ea25f90a64 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -387,4 +387,16 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
return false;
}
+std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
+ const ParsedName& pn) {
+ if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
+ return {
+ DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
+ DeviceNameUtils::LegacyName(pn.job, pn.replica, pn.task, pn.type,
+ pn.id)};
+ } else {
+ return {};
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 1f32828bae..740aa13fa7 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -152,6 +152,11 @@ class DeviceNameUtils {
static bool SplitDeviceName(StringPiece name, string* task, string* device);
static string ParsedNameToString(const ParsedName& pn);
+
+ // Returns canonical and legacy full names for the given parsed
+ // device name 'pn'. The returned string names are often useful to
+ // lookup devices from a mapping.
+ static std::vector<string> GetNamesForDeviceMappings(const ParsedName& pn);
};
} // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index e44b840967..008100aa44 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -443,6 +443,16 @@ TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) {
MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*");
}
+TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) {
+ DeviceNameUtils::ParsedName p = Name("/job:foo/replica:10/task:0/gpu:1");
+ EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","),
+ "/job:foo/replica:10/task:0/device:GPU:1,"
+ "/job:foo/replica:10/task:0/gpu:1");
+ p.has_task = false;
+ EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","),
+ "");
+}
+
static void BM_ParseFullName(int iters) {
DeviceNameUtils::ParsedName p;
while (iters--) {