diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h | 297 |
1 files changed, 69 insertions, 228 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h index f69c5afcb..801b4f5d7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h @@ -19,8 +19,8 @@ * *****************************************************************/ -#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP -#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP +#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP +#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP namespace Eigen { namespace TensorSycl { @@ -31,283 +31,124 @@ namespace internal { /// expression on the device. /// We have to do that as in Eigen the functors are not stateless so we cannot /// re-instantiate them on the device. -/// We have to pass whatever instantiated to the device. -template <typename Evaluator> -struct FunctorExtractor; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorMap: -template <typename PlainObjectType, int Options_, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>> { - using Dimensions = typename PlainObjectType::Dimensions; - const Dimensions m_dimensions; - const Dimensions& dimensions() const { return m_dimensions; } - FunctorExtractor( - const TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>& expr) - : m_dimensions(expr.dimensions()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorMap -template <typename PlainObjectType, int Options_, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>> { - using Dimensions = typename PlainObjectType::Dimensions; +/// We have to pass instantiated functors to the device. +// This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval). +template <typename Evaluator> struct FunctorExtractor{ + typedef typename Evaluator::Dimensions Dimensions; const Dimensions m_dimensions; const Dimensions& dimensions() const { return m_dimensions; } - FunctorExtractor( - const TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>& - expr) - : m_dimensions(expr.dimensions()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorForcedEvalOp -template <typename Expr, typename Dev> -struct FunctorExtractor<TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>> { - using Dimensions = typename Expr::Dimensions; - const Dimensions m_dimensions; - const Dimensions& dimensions() const { return m_dimensions; } - FunctorExtractor(const TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>& expr) - : m_dimensions(expr.dimensions()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorForcedEvalOp -template <typename Expr, typename Dev> -struct FunctorExtractor<TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>> { - using Dimensions = - typename TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>::Dimensions; - const Dimensions m_dimensions; - const Dimensions& dimensions() const { return m_dimensions; } - FunctorExtractor( - const TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>& expr) - : m_dimensions(expr.dimensions()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorCwiseNullaryOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor( - TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} -}; + FunctorExtractor(const Evaluator& expr) + : m_dimensions(expr.dimensions()) {} -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorCwiseNullaryOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor( - const TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorBroadcastingOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor( - const TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorBroadcastingOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; +/// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp +template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> +struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > { + FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; OP func; - FunctorExtractor( - const TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} + FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr) + : rhsExpr(expr.impl()), func(expr.functor()) {} }; - /// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorCwiseUnaryOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor<TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor( - const TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} -}; +/// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp +template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> +struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> > +: FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{}; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorCwiseUnaryOp -template <typename OP, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor( - const TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()), func(expr.functor()) {} -}; - -/// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorCwiseBinaryOp -template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr; - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; +/// const TensorCwiseBinaryOp +template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> +struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > { + FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; + FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; OP func; - FunctorExtractor( - const TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>& - expr) - : lhsExpr(expr.left_impl()), - rhsExpr(expr.right_impl()), - func(expr.functor()) {} + FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr) + : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseBinaryOp -template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr; - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - OP func; - FunctorExtractor(const TensorEvaluator< - const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>& expr) - : lhsExpr(expr.left_impl()), - rhsExpr(expr.right_impl()), - func(expr.functor()) {} -}; +template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> +struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > +: FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{}; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseTernaryOp -template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, - typename Dev> -struct FunctorExtractor<TensorEvaluator< - const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> { - FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr; - FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr; - FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr; +template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev> +struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > { + FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr; + FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr; + FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr; OP func; - FunctorExtractor(const TensorEvaluator< - const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, - Dev>& expr) - : arg1Expr(expr.arg1Impl()), - arg2Expr(expr.arg2Impl()), - arg3Expr(expr.arg3Impl()), - func(expr.functor()) {} + FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr) + : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseTernaryOp -template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, - typename Dev> -struct FunctorExtractor<TensorEvaluator< - TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> { - FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr; - FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr; - FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr; - OP func; - FunctorExtractor( - const TensorEvaluator< - TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr) - : arg1Expr(expr.arg1Impl()), - arg2Expr(expr.arg2Impl()), - arg3Expr(expr.arg3Impl()), - func(expr.functor()) {} -}; +template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev> +struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > +:FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{}; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorCwiseSelectOp +/// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated. template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<IfExpr, Dev>> ifExpr; - FunctorExtractor<TensorEvaluator<ThenExpr, Dev>> thenExpr; - FunctorExtractor<TensorEvaluator<ElseExpr, Dev>> elseExpr; - FunctorExtractor(const TensorEvaluator< - const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr) - : ifExpr(expr.cond_impl()), - thenExpr(expr.then_impl()), - elseExpr(expr.else_impl()) {} +struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > { + FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr; + FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr; + FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr; + FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr) + : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorCwiseSelectOp +/// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> { - FunctorExtractor<IfExpr> ifExpr; - FunctorExtractor<ThenExpr> thenExpr; - FunctorExtractor<ElseExpr> elseExpr; - FunctorExtractor( - const TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& - expr) - : ifExpr(expr.cond_impl()), - thenExpr(expr.then_impl()), - elseExpr(expr.else_impl()) {} -}; +struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > +:FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {}; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorAssignOp +/// const TensorAssignOp. This is an specialisation without OP so it has to be separated. template <typename LHSExpr, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr; - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - FunctorExtractor( - const TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr) - : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} +struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > { + FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; + FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; + FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr) + : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorAssignOp +/// TensorAssignOp. This is an specialisation without OP so it has to be separated. template <typename LHSExpr, typename RHSExpr, typename Dev> -struct FunctorExtractor< - TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr; - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - FunctorExtractor( - const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr) - : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} -}; +struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> > +:FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{}; + /// specialisation of the \ref FunctorExtractor struct when the node type is -/// TensorEvalToOp +/// const TensorEvalToOp, This is an specialisation without OP so it has to be separated. template <typename RHSExpr, typename Dev> -struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - FunctorExtractor(const TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()) {} +struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > { + FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; + FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr) + : rhsExpr(expr.impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is -/// const TensorEvalToOp +/// TensorEvalToOp. This is a specialisation without OP so it has to be separated. template <typename RHSExpr, typename Dev> -struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>> { - FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr; - FunctorExtractor( - const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr) - : rhsExpr(expr.impl()) {} -}; +struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> > +: FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {}; + /// template deduction function for FunctorExtractor template <typename Evaluator> -auto extractFunctors(const Evaluator& evaluator) - -> FunctorExtractor<Evaluator> { +auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> { return FunctorExtractor<Evaluator>(evaluator); } } // namespace internal } // namespace TensorSycl } // namespace Eigen -#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP +#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP |