diff options
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor.h')
-rw-r--r-- | tensorflow/core/kernels/scatter_functor.h | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index a84d89c296..a27cc83e4c 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -25,6 +25,9 @@ namespace tensorflow { class OpKernelContext; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL namespace scatter_op { @@ -82,10 +85,9 @@ struct ScatterFunctor { typename TTypes<Index>::ConstFlat indices); }; -// Specializations of scatter functor for CPU. -template <typename T, typename Index, scatter_op::UpdateOp op> -struct ScatterFunctor<CPUDevice, T, Index, op> { - Index operator()(OpKernelContext* c, const CPUDevice& d, +template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctorBase { + Index operator()(OpKernelContext* c, const Device& d, typename TTypes<T>::Matrix params, typename TTypes<T>::ConstMatrix updates, typename TTypes<Index>::ConstFlat indices) { @@ -106,6 +108,15 @@ struct ScatterFunctor<CPUDevice, T, Index, op> { } }; +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctor<CPUDevice, T, Index, op> + : ScatterFunctorBase<CPUDevice, T, Index, op>{}; +#if TENSORFLOW_USE_SYCL +template<typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterFunctor<SYCLDevice, T, Index, op> + : ScatterFunctorBase<SYCLDevice, T, Index, op>{}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow |