diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h | 35 |
1 files changed, 32 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h index 6f9ab57af..e26cbdf6d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h @@ -42,6 +42,20 @@ template <typename Evaluator> struct FunctorExtractor{ }; +/// specialisation of the \ref FunctorExtractor struct when the node type does not require anything +///TensorConversionOp +#define SYCLEXTRFUNCCONVERSION(ExprNode, CVQual)\ +template <typename ArgType1, typename ArgType2, typename Dev>\ +struct FunctorExtractor<TensorEvaluator<CVQual ExprNode<ArgType1, ArgType2>, Dev> > {\ + FunctorExtractor<TensorEvaluator<ArgType2, Dev> > subExpr;\ + FunctorExtractor(const TensorEvaluator<CVQual ExprNode<ArgType1, ArgType2>, Dev>& expr)\ + : subExpr(expr.impl()) {}\ +}; + +SYCLEXTRFUNCCONVERSION(TensorConversionOp, const) +SYCLEXTRFUNCCONVERSION(TensorConversionOp, ) +#undef SYCLEXTRFUNCCONVERSION + #define SYCLEXTRTENSORMAPFIXEDSIZE(CVQual)\ template <typename Scalar_, typename Dimensions_, int Options_2, typename IndexType, int Options_, template <class> class MakePointer_, typename Dev>\ struct FunctorExtractor< TensorEvaluator <CVQual TensorMap<TensorFixedSize<Scalar_, Dimensions_, Options_2, IndexType>, Options_, MakePointer_> , Dev> >{\ @@ -169,6 +183,24 @@ SYCLEXTRFUNCREDUCTIONOP(const) SYCLEXTRFUNCREDUCTIONOP() #undef SYCLEXTRFUNCREDUCTIONOP +#define SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(CVQual, ExprNode)\ +template<typename Indices, typename LhsXprType, typename RhsXprType, typename Device>\ +struct FunctorExtractor<TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device>>{\ + typedef TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device> Evaluator;\ + typedef typename Evaluator::Dimensions Dimensions;\ + const Dimensions m_dimensions;\ + EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }\ + FunctorExtractor(const TensorEvaluator<CVQual ExprNode<Indices, LhsXprType, RhsXprType>, Device>& expr)\ + : m_dimensions(expr.dimensions()) {}\ +}; + + +SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(const,TensorContractionOp) +SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(,TensorContractionOp) +SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(const,TensorConvolutionOp) +SYCLEXTRFUNCCONTRACTCONVOLUTIONOP(,TensorConvolutionOp) +#undef SYCLEXTRFUNCCONTRACTCONVOLUTIONOP + /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorSlicingOp. This is an specialisation without OP so it has to be separated. #define SYCLEXTRFUNCTSLICEOP(CVQual)\ @@ -253,9 +285,6 @@ struct FunctorExtractor<TensorEvaluator<CVQual OPEXPR<Param, LHSExpr, RHSExpr>, : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.FUNCCALL) {}\ }; -// TensorContractionOp -SYCLEXTRFUNCCONTRACTCONCAT(TensorContractionOp, indices(), const) -SYCLEXTRFUNCCONTRACTCONCAT(TensorContractionOp, indices(),) // TensorConcatenationOp SYCLEXTRFUNCCONTRACTCONCAT(TensorConcatenationOp, axis(), const) SYCLEXTRFUNCCONTRACTCONCAT(TensorConcatenationOp, axis(),) |