aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-06-28 11:22:46 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-06-28 11:22:46 -0700
commit6e7c76481adeef47535aff2e15526ffa0d00eee0 (patch)
tree352e6e1c3235f781f47a805e93370a5bf5f27287 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
parent878845cb25c1ba9e56883fd0654eafb55a22fc34 (diff)
parent1f61aee5ca3a1372e7cabf6dc8725d4b54ec54ce (diff)
Merge with Eigen head
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h133
1 files changed, 127 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index 647c98d4e..f1ae548f7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -442,12 +442,133 @@ EIGEN_STRONG_INLINE void TensorExecutor<Expression, GpuDevice, Vectorizable, Til
// SYCL Executor policy
#ifdef EIGEN_USE_SYCL
-template <typename Expression, bool Vectorizable>
-class TensorExecutor<Expression, SyclDevice, Vectorizable> {
-public:
- static EIGEN_STRONG_INLINE void run(const Expression &expr, const SyclDevice &device) {
- // call TensorSYCL module
- TensorSycl::run(expr, device);
+template <bool Vectorizable, typename Evaluator>
+struct ExecExprFunctorKernel_impl {
+ typedef typename Evaluator::Index Index;
+ const Index range;
+ const Index vectorizable_threads;
+ Evaluator evaluator;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ExecExprFunctorKernel_impl(
+ const Index range_, const Index vectorizable_threads_,
+ Evaluator evaluator_)
+ : range(range_), vectorizable_threads(vectorizable_threads_),
+ evaluator(evaluator_) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void
+ operator()(cl::sycl::nd_item<1> itemID) {
+ Index gId = static_cast<Index>(itemID.get_global_linear_id());
+ Index total_threads = itemID.get_global_range(0);
+ EIGEN_UNROLL_LOOP
+ for (Index i = gId; i < range; i += total_threads) {
+ evaluator.evalScalar(i);
+ }
+ }
+};
+
+template <typename Evaluator>
+struct ExecExprFunctorKernel_impl<true, Evaluator> {
+ typedef typename Evaluator::Index Index;
+ const Index range;
+ const Index vectorizable_threads;
+ Evaluator evaluator;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ExecExprFunctorKernel_impl(
+ const Index range_, const Index vectorizable_threads_,
+ Evaluator evaluator_)
+ : range(range_), vectorizable_threads(vectorizable_threads_),
+ evaluator(evaluator_) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void
+ operator()(cl::sycl::nd_item<1> itemID) {
+ Index gId = static_cast<Index>(itemID.get_global_linear_id());
+ if (gId < vectorizable_threads) {
+ const Index PacketSize = Eigen::internal::unpacket_traits<
+ typename Evaluator::PacketReturnType>::size;
+ evaluator.evalPacket(gId * PacketSize);
+ gId += (vectorizable_threads * PacketSize);
+ EIGEN_UNROLL_LOOP
+ for (Index i = gId; i < range; i += vectorizable_threads) {
+ evaluator.evalScalar(i);
+ }
+ }
+ }
+};
+
+template <typename Expr, bool NonZeroVectoriseSize, typename Evaluator>
+struct ExecExprFunctorKernel
+ : ExecExprFunctorKernel_impl<
+ ::Eigen::internal::IsVectorizable<Eigen::SyclDevice, Expr>::value,
+ Evaluator> {
+ ExecExprFunctorKernel(const Index range_, const Index vectorizable_threads_,
+ const Evaluator &evaluator)
+ : ExecExprFunctorKernel_impl<
+ ::Eigen::internal::IsVectorizable<Eigen::SyclDevice, Expr>::value,
+ Evaluator>(range_, vectorizable_threads_, evaluator) {}
+};
+
+template <typename Expr, typename Evaluator>
+struct ExecExprFunctorKernel<Expr, false, Evaluator>
+ : ExecExprFunctorKernel_impl<false, Evaluator> {
+ ExecExprFunctorKernel(const Index range_, const Index vectorizable_threads_,
+ const Evaluator &evaluator)
+ : ExecExprFunctorKernel_impl<false, Evaluator>(
+ range_, vectorizable_threads_, evaluator) {}
+};
+
+template <typename Expression, bool Vectorizable, bool Tileable>
+class TensorExecutor<Expression, Eigen::SyclDevice, Vectorizable, Tileable> {
+ public:
+ typedef typename Expression::Index Index;
+ static EIGEN_STRONG_INLINE void run(const Expression &expr, const Eigen::SyclDevice &dev) {
+ Eigen::TensorEvaluator<Expression, Eigen::SyclDevice> evaluator(expr, dev);
+ const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL);
+ if (needs_assign) {
+ Index range, GRange, tileSize;
+ Index total_size = ::Eigen::internal::array_prod(evaluator.dimensions());
+ total_size = (total_size == 0) ? 1 : total_size;
+ const int PacketSize = Eigen::PacketType<
+ typename Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>::CoeffReturnType,
+ Eigen::SyclDevice>::size;
+ Index vectorizable_threads =
+ static_cast<Index>(total_size / PacketSize);
+ dev.parallel_for_setup(vectorizable_threads, tileSize, range, GRange);
+ range = total_size;
+ auto f = [&](cl::sycl::handler &cgh) {
+ evaluator.bind(cgh);
+ typedef ExecExprFunctorKernel<Expression, true,
+ Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>>
+ conditional_vectorized_kernel;
+
+ typedef ExecExprFunctorKernel<Expression, false,
+ Eigen::TensorEvaluator<Expression, Eigen::SyclDevice>>
+ non_vectorized_kernel;
+// This is to make sure that an expression with a size less than vectorized size
+// will not call the vectorized kernel.
+// The reason for having this kernel is that the vectorisable parameter is a
+// compile-time parameter,
+// however, the size of a tensor is a run-time parameter
+ (vectorizable_threads)
+ ? cgh.parallel_for(
+#ifdef EIGEN_SYCL_USE_PROGRAM_CLASS
+ dev.program().template get_kernel<vectorized_kernel>(),
+#endif
+ cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange),
+ cl::sycl::range<1>(tileSize)),
+ conditional_vectorized_kernel(range, vectorizable_threads,
+ evaluator))
+ : cgh.parallel_for(
+#ifdef EIGEN_SYCL_USE_PROGRAM_CLASS
+ dev.program().template get_kernel<non_vectorized_kernel>(),
+#endif
+ cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange),
+ cl::sycl::range<1>(tileSize)),
+ non_vectorized_kernel(range, vectorizable_threads,
+ evaluator));
+ };
+ cl::sycl::event e;
+ EIGEN_SYCL_TRY_CATCH(e = dev.sycl_queue().submit(f));
+ dev.async_synchronize(e);
+ }
+ evaluator.cleanup();
}
};