From bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09 Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Thu, 22 Dec 2016 15:38:30 -0800 Subject: Merge changes from github. Change: 142805270 --- tensorflow/core/kernels/scatter_op.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'tensorflow/core/kernels/scatter_op.cc') diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 604f753db1..827eb7dbca 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -27,6 +27,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, @@ -170,6 +173,20 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA +// Registers GPU kernels. +#if TENSORFLOW_USE_SYCL +#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ + REGISTER_SCATTER_ARITHEMTIC(type, SYCL); + +#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); + +REGISTER_SCATTER_ARITHEMTIC_SYCL(float); +REGISTER_SCATTER_UPDATE_SYCL(float); + +#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_UPDATE_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_SCATTER_ARITHEMTIC #undef REGISTER_SCATTER_ARITHEMTIC_CPU #undef REGISTER_SCATTER_ARITHEMTIC_GPU -- cgit v1.2.3