aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
blob: f456c35aad22d58bf6cc6ecd01c55a9c8f84bef6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorSyclPlaceHolderExpr.h
 *
 * \brief:
 *  This is the specialisation of the placeholder expression based on the
 * operation type
 *
*****************************************************************/

#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP

namespace Eigen {
namespace TensorSycl {
namespace internal {
/// \struct PlaceHolderExpression
/// \brief Computes the PlaceHolder expression type of \p Expr. The PlaceHolder
/// expression is a copy of the expression type in which every leaf node
/// (TensorMap, TensorForcedEvalOp) has been replaced with an
/// Eigen::internal::PlaceHolder.
/// \tparam Expr the (possibly const-qualified) expression node type.
/// \tparam N index given to the right-most leaf of \p Expr; leaves further to
/// the left in the same subtree receive smaller indices (see CalculateIndex).
template <typename Expr, size_t N>
struct PlaceHolderExpression;

template<size_t N, typename... Args>
struct CalculateIndex;

template<size_t N, typename Arg>
struct CalculateIndex<N, Arg>{
  typedef typename PlaceHolderExpression<Arg, N>::Type ArgType;
  typedef utility::tuple::Tuple<ArgType> ArgsTuple;
};

template<size_t N, typename Arg1, typename Arg2>
struct CalculateIndex<N, Arg1, Arg2>{
  static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
  typedef typename PlaceHolderExpression<Arg1, N - Arg2LeafCount>::Type Arg1Type;
  typedef typename PlaceHolderExpression<Arg2, N>::Type Arg2Type;
  typedef utility::tuple::Tuple<Arg1Type, Arg2Type> ArgsTuple;
};

template<size_t N, typename Arg1, typename Arg2, typename Arg3>
struct CalculateIndex<N, Arg1, Arg2, Arg3> {
  static const size_t Arg3LeafCount = LeafCount<Arg3>::Count;
  static const size_t Arg2LeafCount = LeafCount<Arg2>::Count;
  typedef typename PlaceHolderExpression<Arg1, N - Arg3LeafCount - Arg2LeafCount>::Type Arg1Type;
  typedef typename PlaceHolderExpression<Arg2, N - Arg3LeafCount>::Type Arg2Type;
  typedef typename PlaceHolderExpression<Arg3, N>::Type Arg3Type;
  typedef utility::tuple::Tuple<Arg1Type, Arg2Type, Arg3Type> ArgsTuple;
};

/// \struct CategoryHelper
/// \brief Rebuilds a node of template \p Category from an operator type \p OP
/// and the sub-expression types packed in a utility::tuple::Tuple.
template <template <class...> class Category, class OP, class TPL>
struct CategoryHelper;

/// General form: the operator type is the node's first template argument,
/// followed by the unpacked sub-expression types.
template <template <class...> class Category, class OP, class... Elems>
struct CategoryHelper<Category, OP, utility::tuple::Tuple<Elems...> > {
  typedef Category<OP, Elems...> Type;
};

/// NoOP form: the node is rebuilt from the unpacked sub-expression types alone.
template <template <class...> class Category, class... Elems>
struct CategoryHelper<Category, NoOP, utility::tuple::Tuple<Elems...> > {
  typedef Category<Elems...> Type;
};

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorCwiseNullaryOp, TensorCwiseUnaryOp, TensorBroadcastingOp, TensorCwiseBinaryOp,  TensorCwiseTernaryOp
#define OPEXPRCATEGORY(CVQual)\
template <template <class, class... > class Category, typename OP, typename... SubExpr, size_t N>\
struct PlaceHolderExpression<CVQual Category<OP, SubExpr...>, N>{\
  typedef CVQual typename CategoryHelper<Category, OP, typename CalculateIndex<N, SubExpr...>::ArgsTuple>::Type Type;\
};

OPEXPRCATEGORY(const)
OPEXPRCATEGORY()
#undef OPEXPRCATEGORY

/// specialisation of the \ref PlaceHolderExpression when the node is
/// TensorSelectOp. The condition, then and else sub-expressions are each
/// converted by CalculateIndex, and the node is rebuilt with the NoOP form of
/// CategoryHelper, because all three template arguments of TensorSelectOp are
/// sub-expressions (there is no separate operator type).
/// When the generic Category<OP, SubExpr...> pattern also matches, this more
/// specialised pattern is the one selected.
#define SELECTEXPR(CVQual)\
template <typename IfExpr, typename ThenExpr, typename ElseExpr, size_t N>\
struct PlaceHolderExpression<CVQual TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, N> {\
  typedef CVQual typename CategoryHelper<TensorSelectOp, NoOP, typename CalculateIndex<N, IfExpr, ThenExpr, ElseExpr>::ArgsTuple>::Type Type;\
};

SELECTEXPR(const)
SELECTEXPR()
#undef SELECTEXPR

/// \ref PlaceHolderExpression specialisation for TensorAssignOp nodes.
/// Both sides are converted by CalculateIndex, with the right-hand side
/// receiving index N. The node is then rebuilt through the NoOP form of
/// CategoryHelper.
/// When the generic Category<OP, SubExpr...> pattern also matches, this more
/// specialised pattern is the one selected.
#define ASSIGNEXPR(CVQual)\
template <typename LeftExpr, typename RightExpr, size_t N>\
struct PlaceHolderExpression<CVQual TensorAssignOp<LeftExpr, RightExpr>, N> {\
  typedef CVQual typename CategoryHelper<TensorAssignOp, NoOP,\
      typename CalculateIndex<N, LeftExpr, RightExpr>::ArgsTuple>::Type Type;\
};

ASSIGNEXPR(const)
ASSIGNEXPR()
#undef ASSIGNEXPR

/// \ref PlaceHolderExpression specialisation for TensorMap leaves.
/// The map is swapped for an Eigen::internal::PlaceHolder parameterised by
/// the TensorMap type (with the same cv-qualification) and the leaf index N.
#define TENSORMAPEXPR(CVQual)\
template <typename Scalar_, int NumIndices_, int Options_, typename IndexType_, int MapOptions_, template <class> class MakePointer_, size_t N>\
struct PlaceHolderExpression<CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, MapOptions_, MakePointer_>, N> {\
  typedef CVQual Eigen::internal::PlaceHolder<\
      CVQual TensorMap<Tensor<Scalar_, NumIndices_, Options_, IndexType_>, MapOptions_, MakePointer_>, N> Type;\
};

TENSORMAPEXPR(const)
TENSORMAPEXPR()
#undef TENSORMAPEXPR

/// \ref PlaceHolderExpression specialisation for TensorForcedEvalOp nodes.
/// The whole forced-evaluation node is treated as a leaf and swapped for an
/// Eigen::internal::PlaceHolder parameterised by the node type and index N.
/// Its inner expression is not traversed here.
#define FORCEDEVAL(CVQual)\
template <typename Inner, size_t N>\
struct PlaceHolderExpression<CVQual TensorForcedEvalOp<Inner>, N> {\
  typedef CVQual Eigen::internal::PlaceHolder<CVQual TensorForcedEvalOp<Inner>, N> Type;\
};

FORCEDEVAL(const)
FORCEDEVAL()
#undef FORCEDEVAL

/// \ref PlaceHolderExpression specialisation for TensorEvalToOp nodes.
/// The TensorEvalToOp node itself is kept, and only its wrapped expression is
/// converted, using CalculateIndex with the same index N.
#define EVALTO(CVQual)\
template <typename Inner, size_t N>\
struct PlaceHolderExpression<CVQual TensorEvalToOp<Inner>, N> {\
  typedef CVQual TensorEvalToOp<typename CalculateIndex<N, Inner>::ArgType> Type;\
};

EVALTO(const)
EVALTO()
#undef EVALTO

/// \struct createPlaceHolderExpression
/// \brief Entry point that computes the PlaceHolder expression type of the
/// root expression. The right-most leaf of the whole expression is given
/// index LeafCount<RootExpr>::Count - 1.
/// NOTE(review): TotalLeaves is a size_t, so a root with a leaf count of 0
/// would make the starting index wrap around. Presumably such expressions
/// never reach this point; TODO confirm.
template <typename RootExpr>
struct createPlaceHolderExpression {
  static const size_t TotalLeaves = LeafCount<RootExpr>::Count;
  typedef typename PlaceHolderExpression<RootExpr, (TotalLeaves - 1)>::Type Type;
};

}
}
}  // namespace Eigen

#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_PLACEHOLDER_EXPR_HPP