diff options
Diffstat (limited to 'tensorflow/core/framework/register_types_traits.h')
-rw-r--r-- | tensorflow/core/framework/register_types_traits.h | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h index 8f8d9fd08e..c1fe5517c6 100644 --- a/tensorflow/core/framework/register_types_traits.h +++ b/tensorflow/core/framework/register_types_traits.h @@ -21,6 +21,10 @@ limitations under the License. typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" @@ -66,6 +70,17 @@ struct proxy_type_pod<GPUDevice, 2> { typedef Eigen::half type; }; +#ifdef TENSORFLOW_USE_SYCL +template <> +struct proxy_type_pod<SYCLDevice, 8> { + typedef double type; +}; +template <> +struct proxy_type_pod<SYCLDevice, 4> { + typedef float type; +}; +#endif // TENSORFLOW_USE_SYCL + /// If POD we use proxy_type_pod, otherwise this maps to identiy. template <typename Device, typename T> struct proxy_type { @@ -81,6 +96,10 @@ struct proxy_type { TF_CALL_int8(m) TF_CALL_complex128(m) #define TF_CALL_GPU_PROXY_TYPES(m) \ TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m) +#ifdef TENSORFLOW_USE_SYCL +#define TF_CALL_SYCL_PROXY_TYPES(m) \ + TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m) +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ |