// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Mehdi Goli Codeplay Software Ltd. // Ralph Potter Codeplay Software Ltd. // Luke Iwanski Codeplay Software Ltd. // Contact: // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. /***************************************************************** * TensorSyclextractFunctors.h * * \brief: * Used to extract all the functors allocated to each node of the expression *tree. * *****************************************************************/ #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP namespace Eigen { namespace TensorSycl { namespace internal { /// \struct FunctorExtractor: This struct is used to extract the functors /// constructed on /// the host-side, to pack them and reuse them in reconstruction of the /// expression on the device. /// We have to do that as in Eigen the functors are not stateless so we cannot /// re-instantiate them on the device. /// We have to pass whatever instantiated to the device. template struct FunctorExtractor; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorMap: template struct FunctorExtractor< TensorEvaluator, Dev>> { using Dimensions = typename PlainObjectType::Dimensions; const Dimensions m_dimensions; const Dimensions& dimensions() const { return m_dimensions; } FunctorExtractor( const TensorEvaluator, Dev>& expr) : m_dimensions(expr.dimensions()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorMap template struct FunctorExtractor< TensorEvaluator, Dev>> { using Dimensions = typename PlainObjectType::Dimensions; const Dimensions m_dimensions; const Dimensions& dimensions() const { return m_dimensions; } FunctorExtractor( const TensorEvaluator, Dev>& expr) : m_dimensions(expr.dimensions()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorForcedEvalOp template struct FunctorExtractor, Dev>> { using Dimensions = typename Expr::Dimensions; const Dimensions m_dimensions; const Dimensions& dimensions() const { return m_dimensions; } FunctorExtractor(const TensorEvaluator, Dev>& expr) : m_dimensions(expr.dimensions()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorForcedEvalOp template struct FunctorExtractor, Dev>> { using Dimensions = typename TensorEvaluator, Dev>::Dimensions; const Dimensions m_dimensions; const Dimensions& dimensions() const { return m_dimensions; } FunctorExtractor( const TensorEvaluator, Dev>& expr) : m_dimensions(expr.dimensions()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseNullaryOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseNullaryOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorBroadcastingOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorBroadcastingOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseUnaryOp template struct FunctorExtractor, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseUnaryOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseBinaryOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> lhsExpr; FunctorExtractor> rhsExpr; OP func; FunctorExtractor( const TensorEvaluator, Dev>& expr) : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseBinaryOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> lhsExpr; FunctorExtractor> rhsExpr; OP func; FunctorExtractor(const TensorEvaluator< const TensorCwiseBinaryOp, Dev>& expr) : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseTernaryOp template struct FunctorExtractor, Dev>> { FunctorExtractor> arg1Expr; FunctorExtractor> arg2Expr; FunctorExtractor> arg3Expr; OP func; FunctorExtractor(const TensorEvaluator< const TensorCwiseTernaryOp, Dev>& expr) : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseTernaryOp template struct FunctorExtractor, Dev>> { FunctorExtractor> arg1Expr; FunctorExtractor> arg2Expr; FunctorExtractor> arg3Expr; OP func; FunctorExtractor( const TensorEvaluator< TensorCwiseTernaryOp, Dev>& expr) : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorCwiseSelectOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> ifExpr; FunctorExtractor> thenExpr; FunctorExtractor> elseExpr; FunctorExtractor(const TensorEvaluator< const TensorSelectOp, Dev>& expr) : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorCwiseSelectOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor ifExpr; FunctorExtractor thenExpr; FunctorExtractor elseExpr; FunctorExtractor( const TensorEvaluator, Dev>& expr) : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorAssignOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> lhsExpr; FunctorExtractor> rhsExpr; FunctorExtractor( const TensorEvaluator, Dev>& expr) : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorAssignOp template struct FunctorExtractor< TensorEvaluator, Dev>> { FunctorExtractor> lhsExpr; FunctorExtractor> rhsExpr; FunctorExtractor( const TensorEvaluator, Dev>& expr) : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// TensorEvalToOp template struct FunctorExtractor, Dev>> { FunctorExtractor> rhsExpr; FunctorExtractor(const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()) {} }; /// specialisation of the \ref FunctorExtractor struct when the node type is /// const TensorEvalToOp template struct FunctorExtractor, Dev>> { FunctorExtractor> rhsExpr; FunctorExtractor( const TensorEvaluator, Dev>& expr) : rhsExpr(expr.impl()) {} }; /// template deduction function for FunctorExtractor template auto extractFunctors(const Evaluator& evaluator) -> FunctorExtractor { return FunctorExtractor(evaluator); } } // namespace internal } // namespace TensorSycl } // namespace Eigen #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP