diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 76 |
1 files changed, 53 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 7c64529441..9bb030b220 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -28,30 +28,60 @@ namespace { class PinToHostOptimizerTest : public GrapplerTest {}; -TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) { gtl::FlatSet<string> devices = {}; - EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); - - devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), - "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), - "/device:CPU:0"); - - devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_CPU:0"); - - devices = {"/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_GPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_GPU:*"); + + string device = "ABC"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "ABC"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:*"); } TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { |