aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_functor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor.h')
-rw-r--r--tensorflow/core/kernels/scatter_functor.h19
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index a84d89c296..a27cc83e4c 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -25,6 +25,9 @@ namespace tensorflow {
class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace scatter_op {
@@ -82,10 +85,9 @@ struct ScatterFunctor {
typename TTypes<Index>::ConstFlat indices);
};
-// Specializations of scatter functor for CPU.
-template <typename T, typename Index, scatter_op::UpdateOp op>
-struct ScatterFunctor<CPUDevice, T, Index, op> {
- Index operator()(OpKernelContext* c, const CPUDevice& d,
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctorBase {
+ Index operator()(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
@@ -106,6 +108,15 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
}
};
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<CPUDevice, T, Index, op>
+ : ScatterFunctorBase<CPUDevice, T, Index, op>{};
+#if TENSORFLOW_USE_SYCL
+template<typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterFunctor<SYCLDevice, T, Index, op>
+ : ScatterFunctorBase<SYCLDevice, T, Index, op>{};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow