diff options
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.h | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 1bb1a9fc50..625cea4228 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -40,6 +40,9 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
 
 template <typename Device>
 struct Constants {
@@ -60,13 +63,16 @@ struct Constants {
 };
 
 #if defined(EIGEN_HAS_INDEX_LIST)
-template <>
-struct Constants<CPUDevice> {
+struct ConstantsBase {
   const Eigen::IndexList<Eigen::type2index<0>> kZero;
   const Eigen::IndexList<Eigen::type2index<1>> kOne;
   const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
 };
-#endif
+template<> struct Constants<CPUDevice> : ConstantsBase{};
+#ifdef TENSORFLOW_USE_SYCL
+template<> struct Constants<SYCLDevice> : ConstantsBase{};
+#endif // TENSORFLOW_USE_SYCL
+#endif // EIGEN_HAS_INDEX_LIST
 
 class ReductionHelper {
  public:
@@ -239,22 +245,31 @@ class ReductionOp : public OpKernel {
 
 namespace functor {
 
-template <typename Reducer>
-struct ReduceFunctor<CPUDevice, Reducer> {
+template <typename Device, typename Reducer>
+struct ReduceFunctorBase {
   template <typename OUT_T, typename IN_T, typename ReductionAxes>
-  static void Reduce(const CPUDevice& d, OUT_T out, IN_T in,
+  static void Reduce(const Device& d, OUT_T out, IN_T in,
                      const ReductionAxes& reduction_axes,
                      const Reducer& reducer) {
     ReduceEigenImpl(d, out, in, reduction_axes, reducer);
   }
 
   template <typename OUT_T>
-  static void FillIdentity(const CPUDevice& d, OUT_T out,
+  static void FillIdentity(const Device& d, OUT_T out,
                            const Reducer& reducer) {
     FillIdentityEigenImpl(d, out, reducer);
   }
 };
 
+template <typename Reducer>
+struct ReduceFunctor<CPUDevice, Reducer>
+    : ReduceFunctorBase<CPUDevice, Reducer>{};
+#if TENSORFLOW_USE_SYCL
+template <typename Reducer>
+struct ReduceFunctor<SYCLDevice, Reducer>
+    : ReduceFunctorBase<SYCLDevice, Reducer>{};
+#endif // TENSORFLOW_USE_SYCL
+
 } // namespace functor
 } // namespace tensorflow
 