aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/register_types_traits.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/register_types_traits.h')
-rw-r--r--tensorflow/core/framework/register_types_traits.h19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h
index 8f8d9fd08e..c1fe5517c6 100644
--- a/tensorflow/core/framework/register_types_traits.h
+++ b/tensorflow/core/framework/register_types_traits.h
@@ -21,6 +21,10 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
@@ -66,6 +70,17 @@ struct proxy_type_pod<GPUDevice, 2> {
typedef Eigen::half type;
};
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+struct proxy_type_pod<SYCLDevice, 8> {
+ typedef double type;
+};
+template <>
+struct proxy_type_pod<SYCLDevice, 4> {
+ typedef float type;
+};
+#endif // TENSORFLOW_USE_SYCL
+
/// If POD we use proxy_type_pod, otherwise this maps to identiy.
template <typename Device, typename T>
struct proxy_type {
@@ -81,6 +96,10 @@ struct proxy_type {
TF_CALL_int8(m) TF_CALL_complex128(m)
#define TF_CALL_GPU_PROXY_TYPES(m) \
TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m)
+#ifdef TENSORFLOW_USE_SYCL
+#define TF_CALL_SYCL_PROXY_TYPES(m) \
+ TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_