diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device_factory.cc')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device_factory.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc index a643fc7258..19c14770dc 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc @@ -18,24 +18,34 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/sycl/sycl_device.h" +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" + namespace tensorflow { class SYCLDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions &options, const string &name_prefix, std::vector<Device *> *devices) override { - int n = 1; + + auto syclInterface = GSYCLInterface::instance(); + + size_t n = 1; auto iter = options.config.device_count().find("SYCL"); if (iter != options.config.device_count().end()) { n = iter->second; } + for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:SYCL:", i); devices->push_back( - new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality(), - SYCLDevice::GetShortDeviceDescription(), - cl::sycl::gpu_selector(), cpu_allocator())); + new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality() + , syclInterface->GetShortDeviceDescription(i) + , syclInterface->GetSYCLAllocator(i) + , syclInterface->GetCPUAllocator(i) + , syclInterface->GetSYCLContext(i)) + ); } + return Status::OK(); } }; |