aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/device_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/device_base.h')
-rw-r--r--tensorflow/core/framework/device_base.h17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 8f0075dcd6..6edbda1276 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -30,6 +30,9 @@ limitations under the License.
namespace Eigen {
struct ThreadPoolDevice;
+#ifdef TENSORFLOW_USE_SYCL
+struct SyclDevice;
+#endif
} // end namespace Eigen
namespace perftools {
@@ -145,6 +148,10 @@ class DeviceBase {
eigen_cpu_device_ = d;
}
+#ifdef TENSORFLOW_USE_SYCL
+ void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
+#endif
+
// Return the Allocator implementation to use based on the allocator
// attributes requested. See allocator.h for more details.
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
@@ -167,6 +174,13 @@ class DeviceBase {
return eigen_cpu_device_;
}
+#ifdef TENSORFLOW_USE_SYCL
+ const Eigen::SyclDevice* eigen_sycl_device() const {
+ CHECK(eigen_sycl_device_ != nullptr);
+ return eigen_sycl_device_;
+ }
+#endif
+
// Caller owns the return value. The OpKernelContext calls this even
// for devices that do not implement an eigen_gpu_device. Overridden
// by GPU devices to return a derived type.
@@ -203,6 +217,9 @@ class DeviceBase {
CpuWorkerThreads* cpu_worker_threads_ = nullptr;
GpuDeviceInfo* gpu_device_info_ = nullptr;
Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+#ifdef TENSORFLOW_USE_SYCL
+ Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
+#endif
};
} // namespace tensorflow