diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/common_runtime/device_set.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/common_runtime/device_set.cc')
-rw-r--r-- | tensorflow/core/common_runtime/device_set.cc | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc new file mode 100644 index 0000000000..3b0465d9a6 --- /dev/null +++ b/tensorflow/core/common_runtime/device_set.cc @@ -0,0 +1,68 @@ +#include "tensorflow/core/common_runtime/device_set.h" + +#include <set> +#include <utility> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +DeviceSet::DeviceSet() {} + +DeviceSet::~DeviceSet() {} + +void DeviceSet::AddDevice(Device* device) { + devices_.push_back(device); + device_by_name_.insert({device->name(), device}); +} + +void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, + std::vector<Device*>* devices) const { + // TODO(jeff): If we are going to repeatedly lookup the set of devices + // for the same spec, maybe we should have a cache of some sort + devices->clear(); + for (Device* d : devices_) { + if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) { + devices->push_back(d); + } + } +} + +Device* DeviceSet::FindDeviceByName(const string& name) const { + return gtl::FindPtrOrNull(device_by_name_, name); +} + +// Higher result implies lower priority. +static int Order(const DeviceType& d) { + if (StringPiece(d.type()) == DEVICE_CPU) { + return 3; + } else if (StringPiece(d.type()) == DEVICE_GPU) { + return 2; + } else { + return 1; + } +} + +static bool ByPriority(const DeviceType& a, const DeviceType& b) { + // Order by "order number"; break ties lexicographically. + return std::make_pair(Order(a), StringPiece(a.type())) < + std::make_pair(Order(b), StringPiece(b.type())); +} + +std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const { + std::vector<DeviceType> result; + std::set<string> seen; + for (Device* d : devices_) { + auto t = d->device_type(); + if (seen.insert(t).second) { + result.emplace_back(DeviceType(t)); + } + } + std::sort(result.begin(), result.end(), ByPriority); + return result; +} + +} // namespace tensorflow |