diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h | 43 |
1 files changed, 14 insertions, 29 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h index 3758d46a0..57f2dda26 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h @@ -20,8 +20,8 @@ * *****************************************************************/ -#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSORSYCL_SYCLRUN_HPP -#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSORSYCL_SYCLRUN_HPP +#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_SYCLRUN_HPP +#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_SYCLRUN_HPP namespace Eigen { namespace TensorSycl { @@ -34,17 +34,14 @@ void run(Expr &expr, Dev &dev) { Eigen::TensorEvaluator<Expr, Dev> evaluator(expr, dev); const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL); if (needs_assign) { - using PlaceHolderExpr = - typename internal::createPlaceHolderExpression<Expr>::Type; + typedef typename internal::createPlaceHolderExpression<Expr>::Type PlaceHolderExpr; auto functors = internal::extractFunctors(evaluator); dev.m_queue.submit([&](cl::sycl::handler &cgh) { // create a tuple of accessors from Evaluator - auto tuple_of_accessors = - internal::createTupleOfAccessors<decltype(evaluator)>(cgh, evaluator); - const auto range = - utility::tuple::get<0>(tuple_of_accessors).get_range()[0]; + auto tuple_of_accessors = internal::createTupleOfAccessors<decltype(evaluator)>(cgh, evaluator); + const auto range = utility::tuple::get<0>(tuple_of_accessors).get_range()[0]; size_t outTileSize = range; if (range > 64) outTileSize = 64; @@ -53,26 +50,14 @@ void run(Expr &expr, Dev &dev) { if (yMode != 0) yRange += (outTileSize - yMode); // run the kernel - cgh.parallel_for<PlaceHolderExpr>( - cl::sycl::nd_range<1>(cl::sycl::range<1>(yRange), - cl::sycl::range<1>(outTileSize)), - [=](cl::sycl::nd_item<1> itemID) { - using DevExpr = - typename internal::ConvertToDeviceExpression<Expr>::Type; - - auto device_expr = - internal::createDeviceExpression<DevExpr, PlaceHolderExpr>( - functors, tuple_of_accessors); - auto device_evaluator = - Eigen::TensorEvaluator<decltype(device_expr.expr), - Eigen::DefaultDevice>( - device_expr.expr, Eigen::DefaultDevice()); - - if (itemID.get_global_linear_id() < range) { - device_evaluator.evalScalar( - static_cast<int>(itemID.get_global_linear_id())); - } - }); + cgh.parallel_for<PlaceHolderExpr>( cl::sycl::nd_range<1>(cl::sycl::range<1>(yRange), cl::sycl::range<1>(outTileSize)), [=](cl::sycl::nd_item<1> itemID) { + typedef typename internal::ConvertToDeviceExpression<Expr>::Type DevExpr; + auto device_expr =internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors); + auto device_evaluator = Eigen::TensorEvaluator<decltype(device_expr.expr), Eigen::DefaultDevice>(device_expr.expr, Eigen::DefaultDevice()); + if (itemID.get_global_linear_id() < range) { + device_evaluator.evalScalar(static_cast<int>(itemID.get_global_linear_id())); + } + }); }); dev.m_queue.throw_asynchronous(); } @@ -81,4 +66,4 @@ void run(Expr &expr, Dev &dev) { } // namespace TensorSycl } // namespace Eigen -#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSORSYCL_SYCLRUN_HPP +#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_SYCLRUN_HPP |