diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-06-28 11:22:46 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-06-28 11:22:46 -0700 |
commit | 6e7c76481adeef47535aff2e15526ffa0d00eee0 (patch) | |
tree | 352e6e1c3235f781f47a805e93370a5bf5f27287 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | |
parent | 878845cb25c1ba9e56883fd0654eafb55a22fc34 (diff) | |
parent | 1f61aee5ca3a1372e7cabf6dc8725d4b54ec54ce (diff) |
Merge with Eigen head
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 133 |
1 files changed, 127 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 647c98d4e..f1ae548f7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -442,12 +442,133 @@ EIGEN_STRONG_INLINE void TensorExecutor<Expression, GpuDevice, Vectorizable, Til // SYCL Executor policy #ifdef EIGEN_USE_SYCL -template <typename Expression, bool Vectorizable> -class TensorExecutor<Expression, SyclDevice, Vectorizable> { -public: - static EIGEN_STRONG_INLINE void run(const Expression &expr, const SyclDevice &device) { - // call TensorSYCL module - TensorSycl::run(expr, device); +template <bool Vectorizable, typename Evaluator> +struct ExecExprFunctorKernel_impl { + typedef typename Evaluator::Index Index; + const Index range; + const Index vectorizable_threads; + Evaluator evaluator; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ExecExprFunctorKernel_impl( + const Index range_, const Index vectorizable_threads_, + Evaluator evaluator_) + : range(range_), vectorizable_threads(vectorizable_threads_), + evaluator(evaluator_) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void + operator()(cl::sycl::nd_item<1> itemID) { + Index gId = static_cast<Index>(itemID.get_global_linear_id()); + Index total_threads = itemID.get_global_range(0); + EIGEN_UNROLL_LOOP + for (Index i = gId; i < range; i += total_threads) { + evaluator.evalScalar(i); + } + } +}; + +template <typename Evaluator> +struct ExecExprFunctorKernel_impl<true, Evaluator> { + typedef typename Evaluator::Index Index; + const Index range; + const Index vectorizable_threads; + Evaluator evaluator; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ExecExprFunctorKernel_impl( + const Index range_, const Index vectorizable_threads_, + Evaluator evaluator_) + : range(range_), vectorizable_threads(vectorizable_threads_), + evaluator(evaluator_) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void + operator()(cl::sycl::nd_item<1> itemID) { + Index gId = static_cast<Index>(itemID.get_global_linear_id()); + if (gId < vectorizable_threads) { + const Index PacketSize = Eigen::internal::unpacket_traits< + typename Evaluator::PacketReturnType>::size; + evaluator.evalPacket(gId * PacketSize); + gId += (vectorizable_threads * PacketSize); + EIGEN_UNROLL_LOOP + for (Index i = gId; i < range; i += vectorizable_threads) { + evaluator.evalScalar(i); + } + } + } +}; + +template <typename Expr, bool NonZeroVectoriseSize, typename Evaluator> +struct ExecExprFunctorKernel + : ExecExprFunctorKernel_impl< + ::Eigen::internal::IsVectorizable<Eigen::SyclDevice, Expr>::value, + Evaluator> { + ExecExprFunctorKernel(const Index range_, const Index vectorizable_threads_, + const Evaluator &evaluator) + : ExecExprFunctorKernel_impl< + ::Eigen::internal::IsVectorizable<Eigen::SyclDevice, Expr>::value, + Evaluator>(range_, vectorizable_threads_, evaluator) {} +}; + +template <typename Expr, typename Evaluator> +struct ExecExprFunctorKernel<Expr, false, Evaluator> + : ExecExprFunctorKernel_impl<false, Evaluator> { + ExecExprFunctorKernel(const Index range_, const Index vectorizable_threads_, + const Evaluator &evaluator) + : ExecExprFunctorKernel_impl<false, Evaluator>( + range_, vectorizable_threads_, evaluator) {} +}; + +template <typename Expression, bool Vectorizable, bool Tileable> +class TensorExecutor<Expression, Eigen::SyclDevice, Vectorizable, Tileable> { + public: + typedef typename Expression::Index Index; + static EIGEN_STRONG_INLINE void run(const Expression &expr, const Eigen::SyclDevice &dev) { + Eigen::TensorEvaluator<Expression, Eigen::SyclDevice> evaluator(expr, dev); + const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL); + if (needs_assign) { + Index range, GRange, tileSize; + Index total_size = ::Eigen::internal::array_prod(evaluator.dimensions()); + total_size = (total_size == 0) ? 1 : total_size; + const int PacketSize = Eigen::PacketType< + typename Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>::CoeffReturnType, + Eigen::SyclDevice>::size; + Index vectorizable_threads = + static_cast<Index>(total_size / PacketSize); + dev.parallel_for_setup(vectorizable_threads, tileSize, range, GRange); + range = total_size; + auto f = [&](cl::sycl::handler &cgh) { + evaluator.bind(cgh); + typedef ExecExprFunctorKernel<Expression, true, + Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>> + conditional_vectorized_kernel; + + typedef ExecExprFunctorKernel<Expression, false, + Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>> + non_vectorized_kernel; +// This is to make sure that an expression with a size less than vectorized size +// will not call the vectorized kernel. +// The reason for having this kernel is that the vectorisable parameter is a +// compile-time parameter, +// however, the size of a tensor is a run-time parameter + (vectorizable_threads) + ? cgh.parallel_for( +#ifdef EIGEN_SYCL_USE_PROGRAM_CLASS + dev.program().template get_kernel<vectorized_kernel>(), +#endif + cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), + cl::sycl::range<1>(tileSize)), + conditional_vectorized_kernel(range, vectorizable_threads, + evaluator)) + : cgh.parallel_for( +#ifdef EIGEN_SYCL_USE_PROGRAM_CLASS + dev.program().template get_kernel<non_vectorized_kernel>(), +#endif + cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), + cl::sycl::range<1>(tileSize)), + non_vectorized_kernel(range, vectorizable_threads, + evaluator)); + }; + cl::sycl::event e; + EIGEN_SYCL_TRY_CATCH(e = dev.sycl_queue().submit(f)); + dev.async_synchronize(e); + } + evaluator.cleanup(); } }; |