diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h | 75 |
1 files changed, 61 insertions, 14 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h index a1c112f4d..234580c7c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h @@ -93,26 +93,58 @@ SYCLFORCEDEVALLEAFCOUNT(const) SYCLFORCEDEVALLEAFCOUNT() #undef SYCLFORCEDEVALLEAFCOUNT +#define SYCLCUSTOMUNARYOPLEAFCOUNT(CVQual)\ +template <typename CustomUnaryFunc, typename XprType>\ +struct LeafCount<CVQual TensorCustomUnaryOp<CustomUnaryFunc, XprType> > {\ +static const size_t Count =1;\ +}; + +SYCLCUSTOMUNARYOPLEAFCOUNT(const) +SYCLCUSTOMUNARYOPLEAFCOUNT() +#undef SYCLCUSTOMUNARYOPLEAFCOUNT + + +#define SYCLCUSTOMBINARYOPLEAFCOUNT(CVQual)\ +template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>\ +struct LeafCount<CVQual TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > {\ +static const size_t Count =1;\ +}; +SYCLCUSTOMBINARYOPLEAFCOUNT( const) +SYCLCUSTOMBINARYOPLEAFCOUNT() +#undef SYCLCUSTOMBINARYOPLEAFCOUNT + /// specialisation of the \ref LeafCount struct when the node type is TensorEvalToOp -#define EVALTOLEAFCOUNT(CVQual)\ +#define EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(CVQual , ExprNode, Num)\ template <typename Expr>\ -struct LeafCount<CVQual TensorEvalToOp<Expr> > {\ - static const size_t Count = 1 + CategoryCount<Expr>::Count;\ +struct LeafCount<CVQual ExprNode<Expr> > {\ + static const size_t Count = Num + CategoryCount<Expr>::Count;\ }; -EVALTOLEAFCOUNT(const) -EVALTOLEAFCOUNT() -#undef EVALTOLEAFCOUNT +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorEvalToOp, 1) +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorEvalToOp, 1) +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorLayoutSwapOp, 0) +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorLayoutSwapOp, 0) + +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorIndexTupleOp, 0) +EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorIndexTupleOp, 0) + +#undef EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT /// specialisation of the \ref LeafCount struct when the node type is const TensorReductionOp -#define REDUCTIONLEAFCOUNT(CVQual)\ +#define REDUCTIONLEAFCOUNT(CVQual, ExprNode)\ template <typename OP, typename Dim, typename Expr>\ -struct LeafCount<CVQual TensorReductionOp<OP, Dim, Expr> > {\ +struct LeafCount<CVQual ExprNode<OP, Dim, Expr> > {\ static const size_t Count =1;\ }; -REDUCTIONLEAFCOUNT(const) -REDUCTIONLEAFCOUNT() +// TensorReductionOp +REDUCTIONLEAFCOUNT(const,TensorReductionOp) +REDUCTIONLEAFCOUNT(,TensorReductionOp) + +// tensor Argmax -TensorTupleReducerOp +REDUCTIONLEAFCOUNT(const, TensorTupleReducerOp) +REDUCTIONLEAFCOUNT(, TensorTupleReducerOp) + #undef REDUCTIONLEAFCOUNT /// specialisation of the \ref LeafCount struct when the node type is const TensorContractionOp @@ -128,8 +160,6 @@ CONTRACTIONCONVOLUTIONLEAFCOUNT(const,TensorConvolutionOp) CONTRACTIONCONVOLUTIONLEAFCOUNT(,TensorConvolutionOp) #undef CONTRACTIONCONVOLUTIONLEAFCOUNT - - /// specialisation of the \ref LeafCount struct when the node type is TensorSlicingOp #define SLICEOPLEAFCOUNT(CVQual)\ template <typename StartIndices, typename Sizes, typename XprType>\ @@ -139,7 +169,6 @@ SLICEOPLEAFCOUNT(const) SLICEOPLEAFCOUNT() #undef SLICEOPLEAFCOUNT - /// specialisation of the \ref LeafCount struct when the node type is TensorChippingOp #define CHIPPINGOPLEAFCOUNT(CVQual)\ template <DenseIndex DimId, typename XprType>\ @@ -149,7 +178,7 @@ CHIPPINGOPLEAFCOUNT(const) CHIPPINGOPLEAFCOUNT() #undef CHIPPINGOPLEAFCOUNT - +///TensorStridingSlicingOp #define SLICESTRIDEOPLEAFCOUNT(CVQual)\ template<typename StartIndices, typename StopIndices, typename Strides, typename XprType>\ struct LeafCount<CVQual TensorStridingSlicingOp<StartIndices, StopIndices, Strides, XprType> >:CategoryCount<XprType>{}; @@ -158,6 +187,24 @@ SLICESTRIDEOPLEAFCOUNT(const) SLICESTRIDEOPLEAFCOUNT() #undef SLICESTRIDEOPLEAFCOUNT +//TensorImagePatchOp +#define TENSORIMAGEPATCHOPLEAFCOUNT(CVQual)\ +template<DenseIndex Rows, DenseIndex Cols, typename XprType>\ +struct LeafCount<CVQual TensorImagePatchOp<Rows, Cols, XprType> >:CategoryCount<XprType>{}; + + +TENSORIMAGEPATCHOPLEAFCOUNT(const) +TENSORIMAGEPATCHOPLEAFCOUNT() +#undef TENSORIMAGEPATCHOPLEAFCOUNT + +// TensorVolumePatchOp +#define TENSORVOLUMEPATCHOPLEAFCOUNT(CVQual)\ +template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>\ +struct LeafCount<CVQual TensorVolumePatchOp<Planes, Rows, Cols, XprType> >:CategoryCount<XprType>{}; + +TENSORVOLUMEPATCHOPLEAFCOUNT(const) +TENSORVOLUMEPATCHOPLEAFCOUNT() +#undef TENSORVOLUMEPATCHOPLEAFCOUNT } /// namespace TensorSycl } /// namespace internal |