aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h96
1 files changed, 83 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
index 74566dcee..9d5708fc5 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
@@ -143,17 +143,52 @@ FORCEDEVAL(const)
FORCEDEVAL()
#undef FORCEDEVAL
+
+/// specialisation of the \ref PlaceHolderExpression when the node is
+/// TensorForcedEvalOp
+#define CUSTOMUNARYOPEVAL(CVQual)\
+template <typename CustomUnaryFunc, typename XprType, size_t N>\
+struct PlaceHolderExpression<CVQual TensorCustomUnaryOp<CustomUnaryFunc, XprType>, N> {\
+ typedef CVQual PlaceHolder<CVQual TensorCustomUnaryOp<CustomUnaryFunc, XprType>, N> Type;\
+};
+
+CUSTOMUNARYOPEVAL(const)
+CUSTOMUNARYOPEVAL()
+#undef CUSTOMUNARYOPEVAL
+
+
/// specialisation of the \ref PlaceHolderExpression when the node is
-/// TensorEvalToOp
-#define EVALTO(CVQual)\
+/// TensorForcedEvalOp
+#define CUSTOMBINARYOPEVAL(CVQual)\
+template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, size_t N>\
+struct PlaceHolderExpression<CVQual TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, N> {\
+ typedef CVQual PlaceHolder<CVQual TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, N> Type;\
+};
+
+CUSTOMBINARYOPEVAL(const)
+CUSTOMBINARYOPEVAL()
+#undef CUSTOMBINARYOPEVAL
+
+
+/// specialisation of the \ref PlaceHolderExpression when the node is
+/// TensoroOp, TensorLayoutSwapOp, and TensorIndexTupleOp
+#define EVALTOLAYOUTSWAPINDEXTUPLE(CVQual, ExprNode)\
template <typename Expr, size_t N>\
-struct PlaceHolderExpression<CVQual TensorEvalToOp<Expr>, N> {\
- typedef CVQual TensorEvalToOp<typename CalculateIndex <N, Expr>::ArgType> Type;\
+struct PlaceHolderExpression<CVQual ExprNode<Expr>, N> {\
+ typedef CVQual ExprNode<typename CalculateIndex <N, Expr>::ArgType> Type;\
};
-EVALTO(const)
-EVALTO()
-#undef EVALTO
+// TensorEvalToOp
+EVALTOLAYOUTSWAPINDEXTUPLE(const, TensorEvalToOp)
+EVALTOLAYOUTSWAPINDEXTUPLE(, TensorEvalToOp)
+//TensorLayoutSwapOp
+EVALTOLAYOUTSWAPINDEXTUPLE(const, TensorLayoutSwapOp)
+EVALTOLAYOUTSWAPINDEXTUPLE(, TensorLayoutSwapOp)
+//TensorIndexTupleOp
+EVALTOLAYOUTSWAPINDEXTUPLE(const, TensorIndexTupleOp)
+EVALTOLAYOUTSWAPINDEXTUPLE(, TensorIndexTupleOp)
+
+#undef EVALTOLAYOUTSWAPINDEXTUPLE
/// specialisation of the \ref PlaceHolderExpression when the node is
@@ -169,17 +204,24 @@ CHIPPINGOP()
#undef CHIPPINGOP
/// specialisation of the \ref PlaceHolderExpression when the node is
-/// TensorReductionOp
-#define SYCLREDUCTION(CVQual)\
+/// TensorReductionOp and TensorTupleReducerOp (Argmax)
+#define SYCLREDUCTION(CVQual, ExprNode)\
template <typename OP, typename Dims, typename Expr, size_t N>\
-struct PlaceHolderExpression<CVQual TensorReductionOp<OP, Dims, Expr>, N>{\
- typedef CVQual PlaceHolder<CVQual TensorReductionOp<OP, Dims,Expr>, N> Type;\
+struct PlaceHolderExpression<CVQual ExprNode<OP, Dims, Expr>, N>{\
+ typedef CVQual PlaceHolder<CVQual ExprNode<OP, Dims,Expr>, N> Type;\
};
-SYCLREDUCTION(const)
-SYCLREDUCTION()
+
+// tensor reduction
+SYCLREDUCTION(const, TensorReductionOp)
+SYCLREDUCTION(, TensorReductionOp)
+
+// tensor Argmax -TensorTupleReducerOp
+SYCLREDUCTION(const, TensorTupleReducerOp)
+SYCLREDUCTION(, TensorTupleReducerOp)
#undef SYCLREDUCTION
+
/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorReductionOp
#define SYCLCONTRACTIONCONVOLUTIONPLH(CVQual, ExprNode)\
@@ -218,6 +260,34 @@ SYCLSLICESTRIDEOPPLH()
#undef SYCLSLICESTRIDEOPPLH
+
+/// specialisation of the \ref PlaceHolderExpression when the node is
+/// TensorImagePatchOp
+#define SYCLTENSORIMAGEPATCHOP(CVQual)\
+template<DenseIndex Rows, DenseIndex Cols, typename XprType, size_t N>\
+struct PlaceHolderExpression<CVQual TensorImagePatchOp<Rows, Cols, XprType>, N> {\
+ typedef CVQual TensorImagePatchOp<Rows, Cols, typename CalculateIndex <N, XprType>::ArgType> Type;\
+};
+
+SYCLTENSORIMAGEPATCHOP(const)
+SYCLTENSORIMAGEPATCHOP()
+#undef SYCLTENSORIMAGEPATCHOP
+
+
+
+/// specialisation of the \ref PlaceHolderExpression when the node is
+/// TensorVolumePatchOp
+#define SYCLTENSORVOLUMEPATCHOP(CVQual)\
+template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType, size_t N>\
+struct PlaceHolderExpression<CVQual TensorVolumePatchOp<Planes,Rows, Cols, XprType>, N> {\
+ typedef CVQual TensorVolumePatchOp<Planes,Rows, Cols, typename CalculateIndex <N, XprType>::ArgType> Type;\
+};
+
+SYCLTENSORVOLUMEPATCHOP(const)
+SYCLTENSORVOLUMEPATCHOP()
+#undef SYCLTENSORVOLUMEPATCHOP
+
+
/// template deduction for \ref PlaceHolderExpression struct
template <typename Expr>
struct createPlaceHolderExpression {