diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h index 3d3142996..4433fec01 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h @@ -31,7 +31,6 @@ template <typename PtrType, size_t N, typename... Params> struct EvalToLHSConstructor { PtrType expr; EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t) : expr(ConvertToActualTypeSycl(typename Eigen::internal::remove_all<PtrType>::type, utility::tuple::get<N>(t))) {} - //EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t): expr((&(*(utility::tuple::get<N>(t).get_pointer())))) {} }; /// \struct ExprConstructor is used to reconstruct the expression on the device and @@ -57,8 +56,6 @@ CVQual PlaceHolder<CVQual TensorMap<T, Options_, MakePointer_>, N>, Params...>{\ : expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())){}\ }; -//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {} - TENSORMAP(const) TENSORMAP() @@ -198,7 +195,6 @@ CVQual PlaceHolder<CVQual TensorForcedEvalOp<DevExpr>, N>, Params...> {\ ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ : expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())) {}\ }; -//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {} FORCEDEVAL(const) FORCEDEVAL() @@ -224,7 +220,6 @@ CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dim, DevExpr>, N>, Params...> {\ ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\ :expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())) {}\ }; -//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {} SYCLREDUCTIONEXPR(const) SYCLREDUCTIONEXPR() @@ -249,6 +244,26 @@ SYCLSLICEOPEXPR() #undef SYCLSLICEOPEXPR +#define SYCLRESHAPEANDSHUFFLEOPEXPRCONST(OPEXPR, CVQual)\ +template<typename Param, typename OrigXprType, typename XprType, typename... Params>\ +struct ExprConstructor<CVQual OPEXPR <Param, OrigXprType> , CVQual OPEXPR <Param, XprType>, Params... >{\ + typedef ExprConstructor<OrigXprType, XprType, Params...> my_xpr_type;\ + typedef CVQual OPEXPR <Param, typename my_xpr_type::Type> Type ;\ + my_xpr_type xprExpr;\ + Type expr;\ + template <typename FuncDetector>\ + ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\ + : xprExpr(funcD.xprExpr, t), expr(xprExpr.expr, funcD.param()) {}\ +}; + +SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorReshapingOp, const) +SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorReshapingOp, ) + +SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorShufflingOp, const) +SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorShufflingOp, ) +#undef SYCLRESHAPEANDSHUFFLEOPEXPRCONST + + /// template deduction for \ref ExprConstructor struct template <typename OrigExpr, typename IndexExpr, typename FuncD, typename... Params> auto createDeviceExpression(FuncD &funcD, const utility::tuple::Tuple<Params...> &t) |