aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h35
1 files changed, 32 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
index 6f9ab57af..e26cbdf6d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
@@ -42,6 +42,20 @@ template <typename Evaluator> struct FunctorExtractor{
};
+/// specialisation of the \ref FunctorExtractor struct when the node type does not require anything
+///TensorConversionOp
+#define SYCLEXTRFUNCCONVERSION(ExprNode, CVQual)\
+template <typename ArgType1, typename ArgType2, typename Dev>\
+struct FunctorExtractor<TensorEvaluator<CVQual ExprNode<ArgType1, ArgType2>, Dev> > {\
+ FunctorExtractor<TensorEvaluator<ArgType2, Dev> > subExpr;\
+ FunctorExtractor(const TensorEvaluator<CVQual ExprNode<ArgType1, ArgType2>, Dev>& expr)\
+ : subExpr(expr.impl()) {}\
+};
+
+SYCLEXTRFUNCCONVERSION(TensorConversionOp, const)
+SYCLEXTRFUNCCONVERSION(TensorConversionOp, )
+#undef SYCLEXTRFUNCCONVERSION
+
#define SYCLEXTRTENSORMAPFIXEDSIZE(CVQual)\
template <typename Scalar_, typename Dimensions_, int Options_2, typename IndexType, int Options_, template <class> class MakePointer_, typename Dev>\
struct FunctorExtractor< TensorEvaluator <CVQual TensorMap<TensorFixedSize<Scalar_, Dimensions_, Options_2, IndexType>, Options_, MakePointer_> , Dev> >{\
@@ -169,6 +183,24 @@ SYCLEXTRFUNCREDUCTIONOP(const)
SYCLEXTRFUNCREDUCTIONOP()
#undef SYCLEXTRFUNCREDUCTIONOP
+#define SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(CVQual, ExprNode)\
+template<typename Indices, typename LhsXprType, typename RhsXprType, typename Device>\
+struct FunctorExtractor<TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device>>{\
+ typedef TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device> Evaluator;\
+ typedef typename Evaluator::Dimensions Dimensions;\
+ const Dimensions m_dimensions;\
+ EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }\
+ FunctorExtractor(const TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device>& expr)\
+ : m_dimensions(expr.dimensions()) {}\
+};
+
+
+SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(const,TensorContractionOp)
+SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(,TensorContractionOp)
+SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(const,TensorConvolutionOp)
+SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(,TensorConvolutionOp)
+#undef SYCLEXTRFUNCCONTRACTCONVOLUTIONOP
+
/// specialisation of the \ref FunctorExtractor struct when the node type is
/// const TensorSlicingOp. This is an specialisation without OP so it has to be separated.
#define SYCLEXTRFUNCTSLICEOP(CVQual)\
@@ -253,9 +285,6 @@ struct FunctorExtractor<TensorEvaluator<CVQual OPEXPR<Param, LHSExpr, RHSExpr>,
: lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.FUNCCALL) {}\
};
-// TensorContractionOp
-SYCLEXTRFUNCCONTRACTCONCAT(TensorContractionOp, indices(), const)
-SYCLEXTRFUNCCONTRACTCONCAT(TensorContractionOp, indices(),)
// TensorConcatenationOp
SYCLEXTRFUNCCONTRACTCONCAT(TensorConcatenationOp, axis(), const)
SYCLEXTRFUNCCONTRACTCONCAT(TensorConcatenationOp, axis(),)