aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/device_set.cc9
-rw-r--r--tensorflow/core/common_runtime/device_set_test.cc20
2 files changed, 19 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc
index 8ff93760d4..84f435b022 100644
--- a/tensorflow/core/common_runtime/device_set.cc
+++ b/tensorflow/core/common_runtime/device_set.cc
@@ -54,10 +54,15 @@ Device* DeviceSet::FindDeviceByName(const string& name) const {
int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
if (StringPiece(d.type()) == DEVICE_CPU) {
return 3;
- } else if (StringPiece(d.type()) == DEVICE_GPU) {
+ } else if (StringPiece(d.type()) == DEVICE_GPU ||
+ StringPiece(d.type()) == DEVICE_SYCL) {
return 2;
} else {
- return 1;
+ // Non-CPU/GPU devices are never prioritized over CPU and GPU, and
+ // must be explicitly selected. This is to prevent surprising
+ // placements that cause a lot of cross-device communication
+ // between the host CPU device and other devices.
+ return 10;
}
}
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
index 8ca744c004..7c59b43489 100644
--- a/tensorflow/core/common_runtime/device_set_test.cc
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -69,18 +69,22 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
types());
AddDevice("SYCL", "/job:a/replica:0/task:0/device:sycl:0");
- EXPECT_EQ(
- (std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType(DEVICE_GPU),
- DeviceType(DEVICE_CPU)}),
- types());
+ EXPECT_TRUE((types()[0] == DeviceType(DEVICE_SYCL) ||
+ types()[0] == DeviceType(DEVICE_GPU)));
+ EXPECT_TRUE((types()[1] == DeviceType(DEVICE_SYCL) ||
+ types()[1] == DeviceType(DEVICE_GPU)));
+ EXPECT_TRUE(types()[2] == DeviceType(DEVICE_CPU));
AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
- EXPECT_EQ((std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType("T1"),
- DeviceType("T2"), DeviceType(DEVICE_GPU),
- DeviceType(DEVICE_CPU)}),
- types());
+ EXPECT_TRUE((types()[0] == DeviceType(DEVICE_SYCL) ||
+ types()[0] == DeviceType(DEVICE_GPU)));
+ EXPECT_TRUE((types()[1] == DeviceType(DEVICE_SYCL) ||
+ types()[1] == DeviceType(DEVICE_GPU)));
+ EXPECT_TRUE(types()[2] == DeviceType(DEVICE_CPU));
+ EXPECT_TRUE(types()[3] == DeviceType("T1"));
+ EXPECT_TRUE(types()[4] == DeviceType("T2"));
}
} // namespace