aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h25
1 files changed, 20 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h
index 3d3142996..4433fec01 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h
@@ -31,7 +31,6 @@ template <typename PtrType, size_t N, typename... Params>
struct EvalToLHSConstructor {
PtrType expr;
EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t) : expr(ConvertToActualTypeSycl(typename Eigen::internal::remove_all<PtrType>::type, utility::tuple::get<N>(t))) {}
- //EvalToLHSConstructor(const utility::tuple::Tuple<Params...> &t): expr((&(*(utility::tuple::get<N>(t).get_pointer())))) {}
};
/// \struct ExprConstructor is used to reconstruct the expression on the device and
@@ -57,8 +56,6 @@ CVQual PlaceHolder<CVQual TensorMap<T, Options_, MakePointer_>, N>, Params...>{\
: expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())){}\
};
-//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}
-
TENSORMAP(const)
TENSORMAP()
@@ -198,7 +195,6 @@ CVQual PlaceHolder<CVQual TensorForcedEvalOp<DevExpr>, N>, Params...> {\
ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\
: expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())) {}\
};
-//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}
FORCEDEVAL(const)
FORCEDEVAL()
@@ -224,7 +220,6 @@ CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dim, DevExpr>, N>, Params...> {\
ExprConstructor(FuncDetector &fd, const utility::tuple::Tuple<Params...> &t)\
:expr(Type(ConvertToActualTypeSycl(typename Type::Scalar, utility::tuple::get<N>(t)), fd.dimensions())) {}\
};
-//: expr(Type((&(*(utility::tuple::get<N>(t).get_pointer()))), fd.dimensions())) {}
SYCLREDUCTIONEXPR(const)
SYCLREDUCTIONEXPR()
@@ -249,6 +244,26 @@ SYCLSLICEOPEXPR()
#undef SYCLSLICEOPEXPR
+#define SYCLRESHAPEANDSHUFFLEOPEXPRCONST(OPEXPR, CVQual)\
+template<typename Param, typename OrigXprType, typename XprType, typename... Params>\
+struct ExprConstructor<CVQual OPEXPR <Param, OrigXprType> , CVQual OPEXPR <Param, XprType>, Params... >{\
+ typedef ExprConstructor<OrigXprType, XprType, Params...> my_xpr_type;\
+ typedef CVQual OPEXPR <Param, typename my_xpr_type::Type> Type ;\
+ my_xpr_type xprExpr;\
+ Type expr;\
+ template <typename FuncDetector>\
+ ExprConstructor(FuncDetector &funcD, const utility::tuple::Tuple<Params...> &t)\
+ : xprExpr(funcD.xprExpr, t), expr(xprExpr.expr, funcD.param()) {}\
+};
+
+SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorReshapingOp, const)
+SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorReshapingOp, )
+
+SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorShufflingOp, const)
+SYCLRESHAPEANDSHUFFLEOPEXPRCONST(TensorShufflingOp, )
+#undef SYCLRESHAPEANDSHUFFLEOPEXPRCONST
+
+
/// template deduction for \ref ExprConstructor struct
template <typename OrigExpr, typename IndexExpr, typename FuncD, typename... Params>
auto createDeviceExpression(FuncD &funcD, const utility::tuple::Tuple<Params...> &t)