aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_gradients.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_gradients.h')
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 671de380d3..77b330f589 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -171,6 +171,21 @@ struct SimpleBinaryFunctor<CPUDevice, Functor> {
}
};
+
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization of BinaryFunctor for SYCL devices
+typedef Eigen::SyclDevice SYCLDevice;
+template <typename Functor>
+struct SimpleBinaryFunctor<SYCLDevice, Functor> {
+ void operator()(const SYCLDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ }
+};
+
+#endif // TENSORFLOW_USE_SYCL
+
template <typename T>
struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};