aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc76
1 files changed, 53 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7c64529441..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {