aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/device_set_test.cc
blob: 1b80a5b697f9faa45ed9610ba4588364cc728c65 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "tensorflow/core/common_runtime/device_set.h"

#include <memory>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/core/public/status.h"

namespace tensorflow {
namespace {

// Builds a heap-allocated stand-in Device whose attributes carry only the
// given device `type` and full device `name`. Nothing else retains the
// returned pointer, so the caller is responsible for deleting it.
static Device* Dev(const char* type, const char* name) {
  // Minimal concrete Device for tests: Sync() always reports success and no
  // allocator is ever handed out.
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attributes)
        : Device(nullptr, attributes, nullptr) {}
    Status Sync() override { return Status::OK(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };

  DeviceAttributes fake_attributes;
  fake_attributes.set_device_type(type);
  fake_attributes.set_name(name);
  Device* const device = new FakeDevice(fake_attributes);
  return device;
}

class DeviceSetTest : public testing::Test {
 public:
  void AddDevice(const char* type, const char* name) {
    Device* d = Dev(type, name);
    owned_.emplace_back(d);
    devices_.AddDevice(d);
  }

  std::vector<DeviceType> types() const {
    return devices_.PrioritizedDeviceTypeList();
  }

 private:
  DeviceSet devices_;
  std::vector<std::unique_ptr<Device>> owned_;
};

// Checks that PrioritizedDeviceTypeList() reports each distinct device type
// exactly once, in priority order, as devices are added to the set.
TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
  using TypeList = std::vector<DeviceType>;
  const DeviceType cpu(DEVICE_CPU);
  const DeviceType gpu(DEVICE_GPU);

  // No devices registered yet: the list is empty.
  EXPECT_EQ(TypeList{}, types());

  // A single CPU device yields a single CPU entry.
  AddDevice("CPU", "/job:a/replica:0/task:0/cpu:0");
  EXPECT_EQ(TypeList{cpu}, types());

  // A second CPU device must not produce a duplicate CPU entry.
  AddDevice("CPU", "/job:a/replica:0/task:0/cpu:1");
  EXPECT_EQ(TypeList{cpu}, types());

  // GPU is expected to rank ahead of CPU. The extra parentheses keep the
  // brace-list comma from splitting the macro arguments.
  AddDevice("GPU", "/job:a/replica:0/task:0/gpu:0");
  EXPECT_EQ((TypeList{gpu, cpu}), types());

  // The custom types T1 and T2 are expected ahead of both GPU and CPU. The
  // T1-before-T2 order is presumably a name-based tie-break inside DeviceSet
  // (NOTE(review): confirm against device_set.cc).
  AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
  AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
  AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
  EXPECT_EQ((TypeList{DeviceType("T1"), DeviceType("T2"), gpu, cpu}),
            types());
}

}  // namespace
}  // namespace tensorflow