aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h75
1 files changed, 61 insertions, 14 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
index a1c112f4d..234580c7c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
@@ -93,26 +93,58 @@ SYCLFORCEDEVALLEAFCOUNT(const)
SYCLFORCEDEVALLEAFCOUNT()
#undef SYCLFORCEDEVALLEAFCOUNT
+#define SYCLCUSTOMUNARYOPLEAFCOUNT(CVQual)\
+template <typename CustomUnaryFunc, typename XprType>\
+struct LeafCount<CVQual TensorCustomUnaryOp<CustomUnaryFunc, XprType> > {\
+static const size_t Count =1;\
+};
+
+SYCLCUSTOMUNARYOPLEAFCOUNT(const)
+SYCLCUSTOMUNARYOPLEAFCOUNT()
+#undef SYCLCUSTOMUNARYOPLEAFCOUNT
+
+
+#define SYCLCUSTOMBINARYOPLEAFCOUNT(CVQual)\
+template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>\
+struct LeafCount<CVQual TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > {\
+static const size_t Count =1;\
+};
+SYCLCUSTOMBINARYOPLEAFCOUNT( const)
+SYCLCUSTOMBINARYOPLEAFCOUNT()
+#undef SYCLCUSTOMBINARYOPLEAFCOUNT
+
/// specialisation of the \ref LeafCount struct when the node type is TensorEvalToOp
-#define EVALTOLEAFCOUNT(CVQual)\
+#define EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(CVQual , ExprNode, Num)\
template <typename Expr>\
-struct LeafCount<CVQual TensorEvalToOp<Expr> > {\
- static const size_t Count = 1 + CategoryCount<Expr>::Count;\
+struct LeafCount<CVQual ExprNode<Expr> > {\
+ static const size_t Count = Num + CategoryCount<Expr>::Count;\
};
-EVALTOLEAFCOUNT(const)
-EVALTOLEAFCOUNT()
-#undef EVALTOLEAFCOUNT
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorEvalToOp, 1)
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorEvalToOp, 1)
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorLayoutSwapOp, 0)
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorLayoutSwapOp, 0)
+
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(const, TensorIndexTupleOp, 0)
+EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT(, TensorIndexTupleOp, 0)
+
+#undef EVALTOLAYOUTSWAPINDEXTUPLELEAFCOUNT
/// specialisation of the \ref LeafCount struct when the node type is const TensorReductionOp
-#define REDUCTIONLEAFCOUNT(CVQual)\
+#define REDUCTIONLEAFCOUNT(CVQual, ExprNode)\
template <typename OP, typename Dim, typename Expr>\
-struct LeafCount<CVQual TensorReductionOp<OP, Dim, Expr> > {\
+struct LeafCount<CVQual ExprNode<OP, Dim, Expr> > {\
static const size_t Count =1;\
};
-REDUCTIONLEAFCOUNT(const)
-REDUCTIONLEAFCOUNT()
+// TensorReductionOp
+REDUCTIONLEAFCOUNT(const,TensorReductionOp)
+REDUCTIONLEAFCOUNT(,TensorReductionOp)
+
+// tensor Argmax -TensorTupleReducerOp
+REDUCTIONLEAFCOUNT(const, TensorTupleReducerOp)
+REDUCTIONLEAFCOUNT(, TensorTupleReducerOp)
+
#undef REDUCTIONLEAFCOUNT
/// specialisation of the \ref LeafCount struct when the node type is const TensorContractionOp
@@ -128,8 +160,6 @@ CONTRACTIONCONVOLUTIONLEAFCOUNT(const,TensorConvolutionOp)
CONTRACTIONCONVOLUTIONLEAFCOUNT(,TensorConvolutionOp)
#undef CONTRACTIONCONVOLUTIONLEAFCOUNT
-
-
/// specialisation of the \ref LeafCount struct when the node type is TensorSlicingOp
#define SLICEOPLEAFCOUNT(CVQual)\
template <typename StartIndices, typename Sizes, typename XprType>\
@@ -139,7 +169,6 @@ SLICEOPLEAFCOUNT(const)
SLICEOPLEAFCOUNT()
#undef SLICEOPLEAFCOUNT
-
/// specialisation of the \ref LeafCount struct when the node type is TensorChippingOp
#define CHIPPINGOPLEAFCOUNT(CVQual)\
template <DenseIndex DimId, typename XprType>\
@@ -149,7 +178,7 @@ CHIPPINGOPLEAFCOUNT(const)
CHIPPINGOPLEAFCOUNT()
#undef CHIPPINGOPLEAFCOUNT
-
+///TensorStridingSlicingOp
#define SLICESTRIDEOPLEAFCOUNT(CVQual)\
template<typename StartIndices, typename StopIndices, typename Strides, typename XprType>\
struct LeafCount<CVQual TensorStridingSlicingOp<StartIndices, StopIndices, Strides, XprType> >:CategoryCount<XprType>{};
@@ -158,6 +187,24 @@ SLICESTRIDEOPLEAFCOUNT(const)
SLICESTRIDEOPLEAFCOUNT()
#undef SLICESTRIDEOPLEAFCOUNT
+//TensorImagePatchOp
+#define TENSORIMAGEPATCHOPLEAFCOUNT(CVQual)\
+template<DenseIndex Rows, DenseIndex Cols, typename XprType>\
+struct LeafCount<CVQual TensorImagePatchOp<Rows, Cols, XprType> >:CategoryCount<XprType>{};
+
+
+TENSORIMAGEPATCHOPLEAFCOUNT(const)
+TENSORIMAGEPATCHOPLEAFCOUNT()
+#undef TENSORIMAGEPATCHOPLEAFCOUNT
+
+// TensorVolumePatchOp
+#define TENSORVOLUMEPATCHOPLEAFCOUNT(CVQual)\
+template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>\
+struct LeafCount<CVQual TensorVolumePatchOp<Planes, Rows, Cols, XprType> >:CategoryCount<XprType>{};
+
+TENSORVOLUMEPATCHOPLEAFCOUNT(const)
+TENSORVOLUMEPATCHOPLEAFCOUNT()
+#undef TENSORVOLUMEPATCHOPLEAFCOUNT
} /// namespace TensorSycl
} /// namespace internal