diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_gradients.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_gradients.h | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h index 671de380d3..77b330f589 100644 --- a/tensorflow/core/kernels/cwise_ops_gradients.h +++ b/tensorflow/core/kernels/cwise_ops_gradients.h @@ -171,6 +171,21 @@ struct SimpleBinaryFunctor<CPUDevice, Functor> { } }; + +#ifdef TENSORFLOW_USE_SYCL +// Partial specialization of BinaryFunctor for SYCL devices +typedef Eigen::SyclDevice SYCLDevice; +template <typename Functor> +struct SimpleBinaryFunctor<SYCLDevice, Functor> { + void operator()(const SYCLDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1) { + out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + } +}; + +#endif // TENSORFLOW_USE_SYCL + template <typename T> struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {}; |