aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reduction_ops_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.h')
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h29
1 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 1bb1a9fc50..625cea4228 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -40,6 +40,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device>
struct Constants {
@@ -60,13 +63,16 @@ struct Constants {
};
#if defined(EIGEN_HAS_INDEX_LIST)
-template <>
-struct Constants<CPUDevice> {
+struct ConstantsBase {
const Eigen::IndexList<Eigen::type2index<0>> kZero;
const Eigen::IndexList<Eigen::type2index<1>> kOne;
const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
-#endif
+template<> struct Constants<CPUDevice> : ConstantsBase{};
+#ifdef TENSORFLOW_USE_SYCL
+template<> struct Constants<SYCLDevice> : ConstantsBase{};
+#endif // TENSORFLOW_USE_SYCL
+#endif // EIGEN_HAS_INDEX_LIST
class ReductionHelper {
public:
@@ -239,22 +245,31 @@ class ReductionOp : public OpKernel {
namespace functor {
-template <typename Reducer>
-struct ReduceFunctor<CPUDevice, Reducer> {
+template <typename Device, typename Reducer>
+struct ReduceFunctorBase {
template <typename OUT_T, typename IN_T, typename ReductionAxes>
- static void Reduce(const CPUDevice& d, OUT_T out, IN_T in,
+ static void Reduce(const Device& d, OUT_T out, IN_T in,
const ReductionAxes& reduction_axes,
const Reducer& reducer) {
ReduceEigenImpl(d, out, in, reduction_axes, reducer);
}
template <typename OUT_T>
- static void FillIdentity(const CPUDevice& d, OUT_T out,
+ static void FillIdentity(const Device& d, OUT_T out,
const Reducer& reducer) {
FillIdentityEigenImpl(d, out, reducer);
}
};
+template <typename Reducer>
+struct ReduceFunctor<CPUDevice, Reducer>
+ : ReduceFunctorBase<CPUDevice, Reducer>{};
+#if TENSORFLOW_USE_SYCL
+template <typename Reducer>
+struct ReduceFunctor<SYCLDevice, Reducer>
+ : ReduceFunctorBase<SYCLDevice, Reducer>{};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow