aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h313
1 files changed, 313 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
new file mode 100644
index 000000000..f69c5afcb
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
@@ -0,0 +1,313 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * TensorSyclextractFunctors.h
+ *
+ * \brief:
+ * Used to extract all the functors allocated to each node of the expression
+*tree.
+ *
+*****************************************************************/
+
+#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP
+#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP
+
+namespace Eigen {
+namespace TensorSycl {
+namespace internal {
+/// \struct FunctorExtractor: This struct is used to extract the functors
+/// constructed on
+/// the host-side, to pack them and reuse them in reconstruction of the
+/// expression on the device.
+/// We have to do that as in Eigen the functors are not stateless so we cannot
+/// re-instantiate them on the device.
+/// We have to pass whatever instantiated to the device.
+template <typename Evaluator>
+struct FunctorExtractor;
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorMap:
+template <typename PlainObjectType, int Options_, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>> {
+ using Dimensions = typename PlainObjectType::Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(
+ const TensorEvaluator<TensorMap<PlainObjectType, Options_>, Dev>& expr)
+ : m_dimensions(expr.dimensions()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorMap
+template <typename PlainObjectType, int Options_, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>> {
+ using Dimensions = typename PlainObjectType::Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(
+ const TensorEvaluator<const TensorMap<PlainObjectType, Options_>, Dev>&
+ expr)
+ : m_dimensions(expr.dimensions()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorForcedEvalOp
+template <typename Expr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>> {
+ using Dimensions = typename Expr::Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(const TensorEvaluator<TensorForcedEvalOp<Expr>, Dev>& expr)
+ : m_dimensions(expr.dimensions()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorForcedEvalOp
+template <typename Expr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>> {
+ using Dimensions =
+ typename TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>::Dimensions;
+ const Dimensions m_dimensions;
+ const Dimensions& dimensions() const { return m_dimensions; }
+ FunctorExtractor(
+ const TensorEvaluator<const TensorForcedEvalOp<Expr>, Dev>& expr)
+ : m_dimensions(expr.dimensions()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseNullaryOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ TensorEvaluator<TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseNullaryOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<const TensorCwiseNullaryOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorBroadcastingOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorBroadcastingOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<const TensorBroadcastingOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseUnaryOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseUnaryOp
+template <typename OP, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<const TensorCwiseUnaryOp<OP, RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()), func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseBinaryOp
+template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>&
+ expr)
+ : lhsExpr(expr.left_impl()),
+ rhsExpr(expr.right_impl()),
+ func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseBinaryOp
+template <typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ OP func;
+ FunctorExtractor(const TensorEvaluator<
+ const TensorCwiseBinaryOp<OP, LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()),
+ rhsExpr(expr.right_impl()),
+ func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseTernaryOp
+template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,
+ typename Dev>
+struct FunctorExtractor<TensorEvaluator<
+ const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr;
+ FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr;
+ FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr;
+ OP func;
+ FunctorExtractor(const TensorEvaluator<
+ const TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>,
+ Dev>& expr)
+ : arg1Expr(expr.arg1Impl()),
+ arg2Expr(expr.arg2Impl()),
+ arg3Expr(expr.arg3Impl()),
+ func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseTernaryOp
+template <typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,
+ typename Dev>
+struct FunctorExtractor<TensorEvaluator<
+ TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<Arg1Expr, Dev>> arg1Expr;
+ FunctorExtractor<TensorEvaluator<Arg2Expr, Dev>> arg2Expr;
+ FunctorExtractor<TensorEvaluator<Arg3Expr, Dev>> arg3Expr;
+ OP func;
+ FunctorExtractor(
+ const TensorEvaluator<
+ TensorCwiseTernaryOp<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
+ : arg1Expr(expr.arg1Impl()),
+ arg2Expr(expr.arg2Impl()),
+ arg3Expr(expr.arg3Impl()),
+ func(expr.functor()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorCwiseSelectOp
+template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<IfExpr, Dev>> ifExpr;
+ FunctorExtractor<TensorEvaluator<ThenExpr, Dev>> thenExpr;
+ FunctorExtractor<TensorEvaluator<ElseExpr, Dev>> elseExpr;
+ FunctorExtractor(const TensorEvaluator<
+ const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
+ : ifExpr(expr.cond_impl()),
+ thenExpr(expr.then_impl()),
+ elseExpr(expr.else_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorCwiseSelectOp
+template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>> {
+ FunctorExtractor<IfExpr> ifExpr;
+ FunctorExtractor<ThenExpr> thenExpr;
+ FunctorExtractor<ElseExpr> elseExpr;
+ FunctorExtractor(
+ const TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>&
+ expr)
+ : ifExpr(expr.cond_impl()),
+ thenExpr(expr.then_impl()),
+ elseExpr(expr.else_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorAssignOp
+template <typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ FunctorExtractor(
+ const TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorAssignOp
+template <typename LHSExpr, typename RHSExpr, typename Dev>
+struct FunctorExtractor<
+ TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev>> lhsExpr;
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ FunctorExtractor(
+ const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
+ : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorEvalToOp
+template <typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ FunctorExtractor(const TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()) {}
+};
+
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// const TensorEvalToOp
+template <typename RHSExpr, typename Dev>
+struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>> {
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev>> rhsExpr;
+ FunctorExtractor(
+ const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
+ : rhsExpr(expr.impl()) {}
+};
+
+/// template deduction function for FunctorExtractor
+template <typename Evaluator>
+auto extractFunctors(const Evaluator& evaluator)
+ -> FunctorExtractor<Evaluator> {
+ return FunctorExtractor<Evaluator>(evaluator);
+}
+} // namespace internal
+} // namespace TensorSycl
+} // namespace Eigen
+
+#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSORYSYCL_EXTRACT_FUNCTORS_HPP