diff options
Diffstat (limited to 'tensorflow/core/common_runtime/device_mgr.cc')
-rw-r--r-- | tensorflow/core/common_runtime/device_mgr.cc | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc new file mode 100644 index 0000000000..4fa13f6b4b --- /dev/null +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -0,0 +1,90 @@ +#include "tensorflow/core/common_runtime/device_mgr.h" + +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) { + for (Device* d : devices) { + devices_.push_back(d); + + // Register under both the full name and the local name. + device_map_[d->name()] = d; + device_map_[DeviceNameUtils::LocalName(d->name())] = d; + device_type_counts_[d->device_type()]++; + } +} + +DeviceMgr::~DeviceMgr() { + for (auto p : devices_) delete p; +} + +void DeviceMgr::ListDeviceAttributes( + std::vector<DeviceAttributes>* devices) const { + devices->reserve(devices_.size()); + for (Device* dev : devices_) { + devices->emplace_back(dev->attributes()); + } +} + +std::vector<Device*> DeviceMgr::ListDevices() const { + return std::vector<Device*>(devices_.begin(), devices_.end()); +} + +string DeviceMgr::DebugString() const { + string out; + for (Device* dev : devices_) { + strings::StrAppend(&out, dev->name(), "\n"); + } + return out; +} + +string DeviceMgr::DeviceMappingString() const { + string out; + for (Device* dev : devices_) { + if (!dev->attributes().physical_device_desc().empty()) { + strings::StrAppend(&out, dev->name(), " -> ", + dev->attributes().physical_device_desc(), "\n"); + } + } + return out; +} + +Status DeviceMgr::LookupDevice(const string& name, Device** device) const { + Status s; + auto iter = device_map_.find(name); + if (iter == device_map_.end()) { + return errors::InvalidArgument(name, " unknown device."); + } + *device = iter->second; + return Status::OK(); +} + +void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const { + Status s; + for (Device* dev : devices_) { + if (containers.empty()) { + s.Update(dev->resource_manager()->Cleanup( + dev->resource_manager()->default_container())); + } else { + for (const string& c : containers) { + s.Update(dev->resource_manager()->Cleanup(c)); + } + } + if (!s.ok()) { + LOG(WARNING) << s; + } + } +} + +int DeviceMgr::NumDeviceType(const string& type) const { + auto iter = device_type_counts_.find(type); + if (iter != device_type_counts_.end()) return iter->second; + return 0; +} + +} // namespace tensorflow |