diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 16:43:27 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 16:43:27 -0800 |
commit | 660da83e18adf77d2507410de7a9b20f3e7dcb85 (patch) | |
tree | 980683db0091222a421a8ba5b970462f542ae810 /unsupported/Eigen/CXX11/src | |
parent | 4236aebe103b0fa54f3b9e7e3c0c12094fa6e200 (diff) | |
parent | 3cfa16f41d0eddb83d15d99ea64af24ee5bdbb0c (diff) |
Pulled latest update from trunk
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h | 42 |
1 files changed, 21 insertions, 21 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h index 11e4ddc56..32930be26 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h @@ -26,25 +26,25 @@ namespace Eigen { namespace TensorSycl { +template<typename Expr, typename FunctorExpr, typename TupleType > struct ExecExprFunctorKernel{ + typedef typename internal::createPlaceHolderExpression<Expr>::Type PlaceHolderExpr; - template<typename Expr, typename FunctorExpr, typename TupleType > struct ExecExprFunctorKernel{ - typedef typename internal::createPlaceHolderExpression<Expr>::Type PlaceHolderExpr; + typedef typename Expr::Index Index; + FunctorExpr functors; + TupleType tuple_of_accessors; + Index range; + ExecExprFunctorKernel(Index range_, FunctorExpr functors_, TupleType tuple_of_accessors_) + : functors(functors_), tuple_of_accessors(tuple_of_accessors_), range(range_){} + void operator()(cl::sycl::nd_item<1> itemID) { + typedef typename internal::ConvertToDeviceExpression<Expr>::Type DevExpr; + auto device_expr =internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors); + auto device_evaluator = Eigen::TensorEvaluator<decltype(device_expr.expr), Eigen::DefaultDevice>(device_expr.expr, Eigen::DefaultDevice()); + typename DevExpr::Index gId = static_cast<typename DevExpr::Index>(itemID.get_global_linear_id()); + if (gId < range) + device_evaluator.evalScalar(gId); + } +}; - typedef typename Expr::Index Index; - FunctorExpr functors; - TupleType tuple_of_accessors; - Index range; - ExecExprFunctorKernel(Index range_, FunctorExpr functors_, TupleType tuple_of_accessors_) - : functors(functors_), tuple_of_accessors(tuple_of_accessors_), range(range_){} - void operator()(cl::sycl::nd_item<1> itemID) { - typedef typename internal::ConvertToDeviceExpression<Expr>::Type DevExpr; - auto device_expr =internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors); - auto device_evaluator = Eigen::TensorEvaluator<decltype(device_expr.expr), Eigen::DefaultDevice>(device_expr.expr, Eigen::DefaultDevice()); - typename DevExpr::Index gId = static_cast<typename DevExpr::Index>(itemID.get_global_linear_id()); - if (gId < range) - device_evaluator.evalScalar(gId); - } - }; /// The run function in tensor sycl convert the expression tree to a buffer /// based expression tree; /// creates the expression tree for the device with accessor to buffers; @@ -54,12 +54,12 @@ void run(Expr &expr, Dev &dev) { Eigen::TensorEvaluator<Expr, Dev> evaluator(expr, dev); const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL); if (needs_assign) { - auto functors = internal::extractFunctors(evaluator); - typedef decltype(functors) FunctorExpr; + typedef decltype(internal::extractFunctors(evaluator)) FunctorExpr; + FunctorExpr functors = internal::extractFunctors(evaluator); dev.sycl_queue().submit([&](cl::sycl::handler &cgh) { // create a tuple of accessors from Evaluator - auto tuple_of_accessors = internal::createTupleOfAccessors<decltype(evaluator)>(cgh, evaluator); - typedef decltype(tuple_of_accessors) TupleType; + typedef decltype(internal::createTupleOfAccessors<decltype(evaluator)>(cgh, evaluator)) TupleType; + TupleType tuple_of_accessors = internal::createTupleOfAccessors<decltype(evaluator)>(cgh, evaluator); typename Expr::Index range, GRange, tileSize; dev.parallel_for_setup(static_cast<typename Expr::Index>(evaluator.dimensions().TotalSize()), tileSize, range, GRange); |