diff options
-rw-r--r-- | tensorflow/core/common_runtime/device_set.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/device_set_test.cc | 20 |
2 files changed, 19 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index 8ff93760d4..84f435b022 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -54,10 +54,15 @@ Device* DeviceSet::FindDeviceByName(const string& name) const { int DeviceSet::DeviceTypeOrder(const DeviceType& d) { if (StringPiece(d.type()) == DEVICE_CPU) { return 3; - } else if (StringPiece(d.type()) == DEVICE_GPU) { + } else if (StringPiece(d.type()) == DEVICE_GPU || + StringPiece(d.type()) == DEVICE_SYCL) { return 2; } else { - return 1; + // Non-CPU/GPU devices are never prioritized over CPU and GPU, and + // must be explicitly selected. This is to prevent surprising + // placements that cause a lot of cross-device communication + // between the host CPU device and other devices. + return 10; } } diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index 8ca744c004..7c59b43489 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -69,18 +69,22 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) { types()); AddDevice("SYCL", "/job:a/replica:0/task:0/device:sycl:0"); - EXPECT_EQ( - (std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType(DEVICE_GPU), - DeviceType(DEVICE_CPU)}), - types()); + EXPECT_TRUE((types()[0] == DeviceType(DEVICE_SYCL) || + types()[0] == DeviceType(DEVICE_GPU))); + EXPECT_TRUE((types()[1] == DeviceType(DEVICE_SYCL) || + types()[1] == DeviceType(DEVICE_GPU))); + EXPECT_TRUE(types()[2] == DeviceType(DEVICE_CPU)); AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0"); AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1"); AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0"); - EXPECT_EQ((std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType("T1"), - DeviceType("T2"), DeviceType(DEVICE_GPU), - DeviceType(DEVICE_CPU)}), - types()); + EXPECT_TRUE((types()[0] == DeviceType(DEVICE_SYCL) || + types()[0] == DeviceType(DEVICE_GPU))); + EXPECT_TRUE((types()[1] == DeviceType(DEVICE_SYCL) || + types()[1] == DeviceType(DEVICE_GPU))); + EXPECT_TRUE(types()[2] == DeviceType(DEVICE_CPU)); + EXPECT_TRUE(types()[3] == DeviceType("T1")); + EXPECT_TRUE(types()[4] == DeviceType("T2")); } } // namespace |