aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h297
1 files changed, 69 insertions, 228 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
index f69c5afcb..801b4f5d7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
@@ -19,8 +19,8 @@
*
*****************************************************************/
-#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP
-#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP
+#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
+#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
namespace Eigen {
namespace TensorSycl {
@@ -31,283 +31,124 @@ namespace internal {
/// expression on the device.
/// We have to do that as in Eigen the functors are not stateless so we cannot
/// re-instantiate them on the device.
-/// We have to pass whatever instantiated to the device.
-template <typename Evaluator>
-struct FunctorExtractor;
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorMap:
-template <typename PlainObjectType, int Options_, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>> {
- using Dimensions = typename PlainObjectType::Dimensions;
- const Dimensions m_dimensions;
- const Dimensions& dimensions() const { return m_dimensions; }
- FunctorExtractor(
- const TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>& expr)
- : m_dimensions(expr.dimensions()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorMap
-template <typename PlainObjectType, int Options_, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>> {
- using Dimensions = typename PlainObjectType::Dimensions;
+/// We have to pass instantiated functors to the device.
+// This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval).
+template <typename Evaluator> struct FunctorExtractor{
+ typedef typename Evaluator::Dimensions Dimensions;
const Dimensions m_dimensions;
const Dimensions& dimensions() const { return m_dimensions; }
- FunctorExtractor(
- const TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>&
- expr)
- : m_dimensions(expr.dimensions()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorForcedEvalOp
-template <typename Expr, typename Dev>
-struct FunctorExtractor<TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>> {
- using Dimensions = typename Expr::Dimensions;
- const Dimensions m_dimensions;
- const Dimensions& dimensions() const { return m_dimensions; }
- FunctorExtractor(const TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>& expr)
- : m_dimensions(expr.dimensions()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorForcedEvalOp
-template <typename Expr, typename Dev>
-struct FunctorExtractor<TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>> {
- using Dimensions =
- typename TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>::Dimensions;
- const Dimensions m_dimensions;
- const Dimensions& dimensions() const { return m_dimensions; }
- FunctorExtractor(
- const TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>& expr)
- : m_dimensions(expr.dimensions()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorCwiseNullaryOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(
- TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
-};
+ FunctorExtractor(const Evaluator& expr)
+ : m_dimensions(expr.dimensions()) {}
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorCwiseNullaryOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(
- const TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorBroadcastingOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(
- const TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorBroadcastingOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+/// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp
+template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
OP func;
- FunctorExtractor(
- const TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
+ FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
};
-
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorCwiseUnaryOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(
- const TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
-};
+/// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp
+template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorCwiseUnaryOp
-template <typename OP, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(
- const TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()), func(expr.functor()) {}
-};
-
-/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorCwiseBinaryOp
-template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+/// const TensorCwiseBinaryOp
+template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
OP func;
- FunctorExtractor(
- const TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>&
- expr)
- : lhsExpr(expr.left_impl()),
- rhsExpr(expr.right_impl()),
- func(expr.functor()) {}
+ FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
/// const TensorCwiseBinaryOp
-template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- OP func;
- FunctorExtractor(const TensorEvaluator<
- const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>& expr)
- : lhsExpr(expr.left_impl()),
- rhsExpr(expr.right_impl()),
- func(expr.functor()) {}
-};
+template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{};
/// specialisation of the \ref FunctorExtractor struct when the node type is
/// const TensorCwiseTernaryOp
-template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,
- typename Dev>
-struct FunctorExtractor<TensorEvaluator<
- const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> {
- FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr;
- FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr;
- FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr;
+template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;
+ FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;
+ FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;
OP func;
- FunctorExtractor(const TensorEvaluator<
- const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>,
- Dev>& expr)
- : arg1Expr(expr.arg1Impl()),
- arg2Expr(expr.arg2Impl()),
- arg3Expr(expr.arg3Impl()),
- func(expr.functor()) {}
+ FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
+ : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
/// TensorCwiseTernaryOp
-template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,
- typename Dev>
-struct FunctorExtractor<TensorEvaluator<
- TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> {
- FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr;
- FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr;
- FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr;
- OP func;
- FunctorExtractor(
- const TensorEvaluator<
- TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
- : arg1Expr(expr.arg1Impl()),
- arg2Expr(expr.arg2Impl()),
- arg3Expr(expr.arg3Impl()),
- func(expr.functor()) {}
-};
+template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev>
+struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >
+:FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorCwiseSelectOp
+/// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated.
template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<IfExpr, Dev>> ifExpr;
- FunctorExtractor<TensorEvaluator<ThenExpr, Dev>> thenExpr;
- FunctorExtractor<TensorEvaluator<ElseExpr, Dev>> elseExpr;
- FunctorExtractor(const TensorEvaluator<
- const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
- : ifExpr(expr.cond_impl()),
- thenExpr(expr.then_impl()),
- elseExpr(expr.else_impl()) {}
+struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr;
+ FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr;
+ FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
+ : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorCwiseSelectOp
+/// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated
template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> {
- FunctorExtractor<IfExpr> ifExpr;
- FunctorExtractor<ThenExpr> thenExpr;
- FunctorExtractor<ElseExpr> elseExpr;
- FunctorExtractor(
- const TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>&
- expr)
- : ifExpr(expr.cond_impl()),
- thenExpr(expr.then_impl()),
- elseExpr(expr.else_impl()) {}
-};
+struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> >
+:FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorAssignOp
+/// const TensorAssignOp. This is an specialisation without OP so it has to be separated.
template <typename LHSExpr, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- FunctorExtractor(
- const TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
- : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
+struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorAssignOp
+/// TensorAssignOp. This is an specialisation without OP so it has to be separated.
template <typename LHSExpr, typename RHSExpr, typename Dev>
-struct FunctorExtractor<
- TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- FunctorExtractor(
- const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
- : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
-};
+struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> >
+:FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
+
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// TensorEvalToOp
+/// const TensorEvalToOp, This is an specialisation without OP so it has to be separated.
template <typename RHSExpr, typename Dev>
-struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- FunctorExtractor(const TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()) {}
+struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
+ FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()) {}
};
/// specialisation of the \ref FunctorExtractor struct when the node type is
-/// const TensorEvalToOp
+/// TensorEvalToOp. This is a specialisation without OP so it has to be separated.
template <typename RHSExpr, typename Dev>
-struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>> {
- FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
- FunctorExtractor(
- const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
- : rhsExpr(expr.impl()) {}
-};
+struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> >
+: FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
+
/// template deduction function for FunctorExtractor
template <typename Evaluator>
-auto extractFunctors(const Evaluator& evaluator)
- -> FunctorExtractor<Evaluator> {
+auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
return FunctorExtractor<Evaluator>(evaluator);
}
} // namespace internal
} // namespace TensorSycl
} // namespace Eigen
-#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP
+#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP