author    Mehdi Goli <mehdi.goli@codeplay.com>    2016-12-14 15:30:37 +0000
committer Mehdi Goli <mehdi.goli@codeplay.com>    2016-12-14 15:30:37 +0000
commit    2d4a091beb9e55664c1475137af7166d524cbc1d
tree      d9e4baec0be3eb3c8a4bb2451701f7e49730daa1
parent    3d59a477201d4d4f34b4332fda699c21387cf726
Adding tensor contraction operation backend for Sycl; adding test for contractionOp sycl backend; adding temporary solution to prevent memory leak in buffer; cleaning up cxx11_tensor_buildins_sycl.h
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
 unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 10
 unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 48
 unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 355
 unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h | 28
 unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 3
 unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h | 2
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h | 15
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h | 20
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h | 23
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h | 13
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h | 15
 unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h | 1
 12 files changed, 430 insertions(+), 103 deletions(-)
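
The new backend makes a contraction expression assigned through .device(sycl_device) execute on a SYCL queue. Below is a minimal usage sketch, assuming a SYCL/ComputeCpp-enabled build; the QueueInterface/SyclDevice setup and the memcpy helpers follow the SYCL test code of this period and are not part of this patch:

    #define EIGEN_USE_SYCL
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      cl::sycl::gpu_selector selector;                       // any supported device selector
      Eigen::QueueInterface queue(selector);                 // owns the sycl queue (see TensorDeviceSycl.h)
      Eigen::SyclDevice sycl_device(&queue);

      Eigen::Tensor<float, 2> lhs(64, 128), rhs(128, 32), out(64, 32);
      lhs.setRandom(); rhs.setRandom();

      // device buffers are plain pointers handed out by SyclDevice::allocate
      float* d_lhs = static_cast<float*>(sycl_device.allocate(lhs.size() * sizeof(float)));
      float* d_rhs = static_cast<float*>(sycl_device.allocate(rhs.size() * sizeof(float)));
      float* d_out = static_cast<float*>(sycl_device.allocate(out.size() * sizeof(float)));
      sycl_device.memcpyHostToDevice(d_lhs, lhs.data(), lhs.size() * sizeof(float));
      sycl_device.memcpyHostToDevice(d_rhs, rhs.data(), rhs.size() * sizeof(float));

      Eigen::TensorMap<Eigen::Tensor<float, 2>> A(d_lhs, 64, 128), B(d_rhs, 128, 32), C(d_out, 64, 32);
      // contract the columns of A with the rows of B
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dims = {{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
      C.device(sycl_device) = A.contract(B, dims);           // dispatches to the evaluator in TensorContractionSycl.h

      sycl_device.memcpyDeviceToHost(out.data(), d_out, out.size() * sizeof(float));
      sycl_device.deallocate(d_lhs); sycl_device.deallocate(d_rhs); sycl_device.deallocate(d_out);
      return 0;
    }

The operands have to live in device memory obtained from SyclDevice::allocate and be wrapped in TensorMap; the result is copied back explicitly once the expression has run.
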
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 20b29e5fd..2ac6abf69 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -156,9 +156,9 @@ struct TensorContractionEvaluatorBase
m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
op.rhsExpression(), op.lhsExpression()), device),
m_device(device),
- m_result(NULL) {
+ m_result(NULL), m_expr_indices(op.indices()) {
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
- static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
+ static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -327,7 +327,7 @@ struct TensorContractionEvaluatorBase
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar * data) {
m_leftImpl.evalSubExprsIfNeeded(NULL);
m_rightImpl.evalSubExprsIfNeeded(NULL);
if (data) {
@@ -564,6 +564,9 @@ struct TensorContractionEvaluatorBase
TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
const Device& m_device;
Scalar* m_result;
+ /// required for sycl
+ const Indices m_expr_indices;
+
};
@@ -621,6 +624,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
}
+
};
} // end namespace Eigen
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index a2d7c7414..6a28024b6 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -22,8 +22,14 @@ enum {
/*
* Implementation of the Eigen blas_data_mapper class for tensors.
*/
-
-template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
+/// The MakePointer class is used by SYCL in order to build the mapper class on the device. For other platforms the default MakePointer is used,
+/// which is Scalar* for CoeffLoader.
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> struct CoeffLoader;
+template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ template <class> class MakePointer_ = MakePointer> class BaseTensorContractionMapper;
+
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> struct CoeffLoader {
enum {
DirectOffsets = false
};
@@ -47,7 +53,7 @@ template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
const Tensor m_tensor;
};
-template <typename Tensor> struct CoeffLoader<Tensor, true> {
+template <typename Tensor, template <class> class MakePointer_> struct CoeffLoader<Tensor, true, MakePointer_> {
enum {
DirectOffsets = true
};
@@ -67,13 +73,14 @@ template <typename Tensor> struct CoeffLoader<Tensor, true> {
}
private:
typedef typename Tensor::Scalar Scalar;
- const Scalar* m_data;
+
+ typename MakePointer_<const Scalar>::Type m_data;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
- int packet_size, bool inner_dim_contiguous, int Alignment>
+ int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
public:
EIGEN_DEVICE_FUNC
@@ -89,7 +96,7 @@ class SimpleTensorContractionMapper {
m_k_strides(k_strides) { }
enum {
- DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
+ DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets
};
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
@@ -206,23 +213,22 @@ class SimpleTensorContractionMapper {
}
protected:
- CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
+ CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
const nocontract_t m_nocontract_strides;
const nocontract_t m_ij_strides;
const contract_t m_contract_strides;
const contract_t m_k_strides;
};
-
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size, bool inner_dim_contiguous,
- bool inner_dim_reordered, int Alignment>
-class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
+ bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
+class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_>
{
public:
- typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;
+ typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor,
@@ -307,11 +313,11 @@ template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
bool inner_dim_contiguous,
- bool inner_dim_reordered, int Alignment>
-class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
+ bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
+class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
{
public:
- typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;
+ typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;
EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor,
@@ -345,14 +351,14 @@ template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
- bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionSubMapper {
public:
typedef typename Tensor::PacketReturnType Packet;
typedef typename unpacket_traits<Packet>::half HalfPacket;
- typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
- typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
+ typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
+ typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
typedef Self LinearMapper;
enum {
@@ -452,14 +458,14 @@ template<typename Scalar_, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
- bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionInputMapper
- : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
+ : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
public:
typedef Scalar_ Scalar;
- typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
- typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
+ typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
typedef SubMapper VectorMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
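
The MakePointer_ template parameter threaded through CoeffLoader, SimpleTensorContractionMapper and the sub/input mappers above exists so that the SYCL kernel can instantiate the same mapper code with a global-address-space pointer instead of the raw Scalar* used on the host. A conceptual sketch; the exact MakeGlobalPointer definition lives in the SYCL headers and the form shown here is an assumption:

    // Default: a raw host pointer (what every non-SYCL backend keeps using).
    template <typename T> struct MakePointer {
      typedef T* Type;
    };

    // SYCL variant (assumed form): a pointer in the OpenCL global address space,
    // which is what a kernel receives from a global-buffer accessor.
    template <typename T> struct MakeGlobalPointer {
      typedef typename cl::sycl::global_ptr<T>::pointer_t Type;
    };

    // Effect on CoeffLoader<Tensor, /*HasRawAccess=*/true, MakePointer_>:
    //   host  : typename MakePointer<const Scalar>::Type       m_data;  // const Scalar*
    //   device: typename MakeGlobalPointer<const Scalar>::Type m_data;  // global-space pointer
    // so the same mapper code compiles both on the host and inside a SYCL kernel.
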
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
new file mode 100644
index 000000000..7e3c73caf
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
@@ -0,0 +1,355 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Mehdi Goli Codeplay Software Ltd.
+// Ralph Potter Codeplay Software Ltd.
+// Luke Iwanski Codeplay Software Ltd.
+// Contact: <eigen@codeplay.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/*****************************************************************
+ * TensorContractionSycl.h
+ *
+ * \brief:
+ *  TensorContractionSycl
+ *
+*****************************************************************/
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
+namespace Eigen {
+
+template <typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
+template<typename Indices, typename LeftArgType, typename RightArgType>
+struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> :
+ public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > {
+
+ typedef const Eigen::SyclDevice Device;
+
+ typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+ typedef TensorContractionEvaluatorBase<Self> Base;
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+ typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
+
+ enum {
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ };
+
+ // Most of the code is assuming that both input tensors are ColMajor. If the
+ // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
+ // If we want to compute A * B = C, where A is LHS and B is RHS, the code
+ // will pretend B is LHS and A is RHS.
+ typedef typename internal::conditional<
+ static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
+ typedef typename internal::conditional<
+ static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
+
+ static const int LDims =
+ internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
+ static const int RDims =
+ internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
+ static const int ContractDims = internal::array_size<Indices>::value;
+
+ typedef array<Index, LDims> left_dim_mapper_t;
+ typedef array<Index, RDims> right_dim_mapper_t;
+
+ typedef array<Index, ContractDims> contract_t;
+ typedef array<Index, LDims - ContractDims> left_nocontract_t;
+ typedef array<Index, RDims - ContractDims> right_nocontract_t;
+
+ static const int NumDims = LDims + RDims - 2 * ContractDims;
+
+ typedef DSizes<Index, NumDims> Dimensions;
+
+ // typedefs needed in evalTo
+ typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
+
+ typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
+
+ typedef typename LeftEvaluator::Dimensions LeftDimensions;
+ typedef typename RightEvaluator::Dimensions RightDimensions;
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
+ Base(op, device) {}
+
+ // We need to redefine this method to make nvcc happy
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
+ this->m_leftImpl.evalSubExprsIfNeeded(NULL);
+ this->m_rightImpl.evalSubExprsIfNeeded(NULL);
+ if (data) {
+ evalTo(data);
+ return false;
+ } else {
+ this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
+ evalTo(this->m_result);
+ return true;
+ }
+ }
+ const Eigen::SyclDevice& device() const {return this->m_device;}
+ void evalTo(Scalar* buffer) const {
+    // Dispatch to the statically specialised evalTyped based on the inner-dimension layout of the operands
+ if (this->m_lhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<true, true, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<true, true, false, Unaligned>(buffer);
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<true, false, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<true, false, false, Unaligned>(buffer);
+ }
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<false, true, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<false, true, false, Unaligned>(buffer);
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<false, false, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<false, false, false, Unaligned>(buffer);
+ }
+ }
+ }
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ void evalTyped(Scalar* buffer) const {
+ // columns in left side, rows in right side
+ const Index k = this->m_k_size;
+ EIGEN_UNUSED_VARIABLE(k)
+ // rows in left side
+ const Index m = this->m_i_size;
+ // columns in right side
+ const Index n = this->m_j_size;
+
+    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
+ this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+ LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
+ this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
+ this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
+ }
+ // required by sycl to construct the expr on the device. Returns original left_impl
+ const TensorEvaluator<LeftArgType, Device>& left_impl() const {
+ return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_leftImpl, this->m_rightImpl);
+ }
+ // required by sycl to construct the expr on the device. Returns original right_impl
+ const TensorEvaluator<RightArgType, Device>& right_impl() const {
+ return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_rightImpl, this->m_leftImpl);
+ }
+ // required by sycl to construct the expr on the device
+ const Indices& indices() const {return this->m_expr_indices;}
+};
+
+/// Dummy container on the device. This is used to avoid calling the constructor of TensorEvaluator for TensorContractionOp. This makes the code much faster.
+template<typename Expr> struct TensorEvaluatorContainer;
+template<typename Indices, typename LeftArgType, typename RightArgType>
+struct TensorEvaluatorContainer<TensorContractionOp<Indices, LeftArgType, RightArgType>>{
+ typedef Eigen::DefaultDevice Device;
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+ typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Eigen::DefaultDevice>::type PacketReturnType;
+ enum {
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ };
+
+ typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
+ typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
+ typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
+
+ TensorEvaluatorContainer(const XprType& op, const Eigen::DefaultDevice& device)
+ : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
+ op.lhsExpression(), op.rhsExpression()), device),
+ m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
+ op.rhsExpression(), op.lhsExpression()), device){}
+  LeftEvaluator m_leftImpl;
+  RightEvaluator m_rightImpl;
+};
+
+#define TileSizeDimM 32 // Tile size for dimension M
+#define TileSizeDimN 32 // Tile size for dimension N
+#define TileSizeDimK 16 // Tile size for dimension K
+#define WorkLoadPerThreadM 4 // Work load per thread in dimension M
+#define WorkLoadPerThreadN 4 // work load per thread in dimension N
+#define LocalThreadSizeM (TileSizeDimM/WorkLoadPerThreadM) // Local thread size for the first dimension (M here)
+#define LocalThreadSizeN (TileSizeDimN/WorkLoadPerThreadN) // Local thread size for the second dimension (N here)
+#define LoadPerThreadLhs ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimN)) // workload per thread for Lhs expression
+#define LoadPerThreadRhs ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimM)) // workload per thread for Rhs expression
+#define RoundUp(x,y) ((((x) + (y) - 1) / (y))*(y)) // RoundUp function to make sure that the global thread count is divisible by the local thread count
+
+template <typename PLEXPR, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct KernelNameConstructor;
+template <typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels {
+template< typename Self, typename Output, typename Index, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
+ static void Run(const Self& self, Output* buffer, Index M, Index N, Index K,
+ ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
+ LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides){
+ // create a tuple of accessors from Evaluator
+ typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<typename Self::XprType>::Type PlaceHolderExpr;
+ typedef KernelNameConstructor<PlaceHolderExpr, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered> KernelName;
+ auto functors = Eigen::TensorSycl::internal::extractFunctors(self);
+ Index roundUpK = RoundUp(K, TileSizeDimK);
+ Index roundUpM = RoundUp(M, TileSizeDimM);
+ Index roundUpN = RoundUp(N, TileSizeDimN);
+ self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) {
+ auto tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<Self>(cgh, self);
+ // Local memory for elements of Lhs
+ cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> localLhs(cl::sycl::range<1>(2* TileSizeDimM * TileSizeDimK), cgh);
+ // Local memory for elements of Rhs
+ cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh);
+ //Output memory
+ auto out_privateRes= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer);
+ // sycl parallel for
+ cgh.parallel_for<KernelName>( cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN), cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)), [=](cl::sycl::nd_item<2> itemID) {
+ typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<typename Self::XprType>::Type DevExpr;
+ auto device_expr =Eigen::TensorSycl::internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors);
+ auto device_evaluator = TensorEvaluatorContainer<DevExpr>(device_expr.expr, Eigen::DefaultDevice());
+ typedef TensorEvaluatorContainer<DevExpr> DevEvaluator;
+ typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
+ typename DevEvaluator::LeftEvaluator, LeftNocontractT,
+ ContractT, 1,
+ lhs_inner_dim_contiguous,
+ false, Unaligned, MakeGlobalPointer> LhsMapper;
+
+ typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
+ typename DevEvaluator::RightEvaluator, RightNocontractT,
+ ContractT, 1,
+ rhs_inner_dim_contiguous,
+ rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper;
+ // initialize data mappers must happen inside the kernel for device eval
+ LhsMapper lhs(device_evaluator.m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
+ RhsMapper rhs(device_evaluator.m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
+ auto out_ptr = ConvertToActualTypeSycl(Output, out_privateRes);
+ // Matmul Kernel
+ // Thread identifiers
+ const int mLocalThreadId = itemID.get_local(0); // Local ID row
+ const int nLocalThreadId = itemID.get_local(1); // Local ID col
+ const int mGroupId = itemID.get_group(0); // Work-group ID row
+    const int nGroupId = itemID.get_group(1); // Work-group ID col
+ const int linearLocalThreadId = nLocalThreadId*LocalThreadSizeM + mLocalThreadId; // linear local thread ID
+ // Allocate register space
+ float privateLhs;
+ float privateRhs[WorkLoadPerThreadN];
+ float privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN];
+    // Initialise the privateRes accumulation registers
+ for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ privateRes[wLPTM][wLPTN] = 0.0f;
+ }
+ }
+
+ // Tile Lhs
+ for (int lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
+      int localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ int localLhsRow = localLhsLinearId% TileSizeDimM;
+ int localLhsCol = localLhsLinearId/TileSizeDimM;
+ // Load the value (wide vector load)
+ int GlobalLhsColId = TileSizeDimK*0 + localLhsCol;
+ localLhs[0 + ((localLhsCol*TileSizeDimM + localLhsRow)*2)] =((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId):static_cast<Output>(0);
+ }
+ // Tile Rhs
+ for (int lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
+ int localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ int localRhsRow = localRhsLinearId% TileSizeDimN;
+ int localRhsCol = localRhsLinearId/TileSizeDimN;
+ // Load the value (wide vector load)
+ int GlobalRhsRowId = TileSizeDimK*0 + localRhsCol;
+ localRhs[0 + ((localRhsCol*TileSizeDimN + localRhsRow) *2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow): static_cast<Output>(0);
+
+ }
+ // Loop over all tiles
+ const int numTiles = roundUpK/TileSizeDimK;
+ int firstHalf=0;
+ do {
+ // Synchronise
+ itemID.barrier(cl::sycl::access::fence_space::local_space);
+ // Load the next tile of Lhs and Rhs into local memory
+ int nextHalf = firstHalf + 1;
+ if (nextHalf < numTiles) {
+ // Tile A
+ for (int lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
+ int localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ int localLhsRow = localLhsLinearId% TileSizeDimM;
+ int localLhsCol = localLhsLinearId/TileSizeDimM;
+ // global K id
+ int GlobalLhsColId = TileSizeDimK*nextHalf + localLhsCol;
+ // Store the loaded value into local memory
+ localLhs[(nextHalf%2) + ((localLhsCol*TileSizeDimM + localLhsRow) *2)] = ((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId): static_cast<Output>(0);
+ }
+ // Tile B
+ for (int lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
+ int localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ int localRhsRow = localRhsLinearId% TileSizeDimN;
+ int localRhsCol = localRhsLinearId/TileSizeDimN;
+ // Load the value (wide vector load)
+ int GlobalRhsRowId = TileSizeDimK*nextHalf + localRhsCol;
+ // Store the loaded vector into local memory
+ localRhs[(nextHalf%2) +((localRhsCol*TileSizeDimN + localRhsRow)*2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow):static_cast<Output>(0);
+ }
+ }
+ // Loop over the values of a single tile
+ for (int k=0; k<TileSizeDimK; k++) {
+ // Cache the values of localRhs in registers
+ for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ int localRhsCol = nLocalThreadId + wLPTN*LocalThreadSizeN;
+ privateRhs[wLPTN] = localRhs[(firstHalf%2) +((k*TileSizeDimN + localRhsCol)*2)];
+ }
+ // Perform the computation
+ for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ int localLhsRow = mLocalThreadId + wLPTM*LocalThreadSizeM;
+ privateLhs = localLhs[(firstHalf%2)+ ((k*TileSizeDimM + localLhsRow)*2)];
+ for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ privateRes[wLPTM][wLPTN] += privateLhs * privateRhs[wLPTN];
+ }
+ }
+ }
+ // Next tile
+ firstHalf++;
+ } while (firstHalf<numTiles);
+
+
+ // Store the final results in C
+ for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ int globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM;
+ if (globalRow< M){
+ for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ int globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN;
+ if(globalCol<N)
+ out_ptr[globalCol*M + globalRow] = privateRes[wLPTM][wLPTN];
+ }
+ }
+ }
+
+ /// End the kernel
+ });
+ });
+ self.device().synchronize();
+ }
+};
+
+} // end namespace Eigen
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
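
To make the tile-size macros and the nd_range computation above concrete, the following standalone sketch reproduces the launch-geometry arithmetic for a small problem (illustration only, not part of the patch):

    #include <cstdio>

    int main() {
      const int M = 100, N = 50, K = 20;                      // output is M x N, contraction length K
      const int TileM = 32, TileN = 32, TileK = 16;           // TileSizeDim{M,N,K}
      const int WptM = 4, WptN = 4;                           // WorkLoadPerThread{M,N}
      auto roundUp = [](int x, int y) { return ((x + y - 1) / y) * y; };

      const int localM = TileM / WptM, localN = TileN / WptN; // 8 x 8 work-items per group
      const int globalM = roundUp(M, TileM) / WptM;           // 128 / 4 = 32
      const int globalN = roundUp(N, TileN) / WptN;           //  64 / 4 = 16
      const int numTiles = roundUp(K, TileK) / TileK;         //  32 / 16 = 2 K-tiles

      std::printf("nd_range global (%d, %d), local (%d, %d)\n", globalM, globalN, localM, localN);
      std::printf("%d x %d work-groups, each writing a %dx%d tile of C; "
                  "each work-item accumulates a %dx%d register block over %d K-tiles\n",
                  globalM / localM, globalN / localN, TileM, TileN, WptM, WptN, numTiles);
      return 0;
    }

With the default 32x32x16 tiles, every work-group owns one 32x32 output tile, and the double-buffered local memory (the factor of 2 in the localLhs/localRhs allocations) lets the next K-tile be loaded while the current one is being consumed.
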
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h
index 40dd5d81a..f92ea1d7b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h
@@ -31,7 +31,7 @@ namespace Eigen {
auto dst_ptr = ConvertToActualTypeSycl(Scalar, m_dst_acc);
auto globalid = itemID.get_global_linear_id();
if (globalid < m_rng) {
- dst_ptr[globalid + m_i] = src_ptr[globalid + m_offset];
+ dst_ptr[globalid + m_i] = src_ptr[globalid + m_offset];
}
}
@@ -50,7 +50,7 @@ EIGEN_STRONG_INLINE auto get_sycl_supported_devices()->decltype(cl::sycl::device
/// get_devices returns all the available opencl devices. Either use device_selector or exclude devices that computecpp does not support (AMD OpenCL for CPU )
auto s= (*it).template get_info<cl::sycl::info::device::vendor>();
std::transform(s.begin(), s.end(), s.begin(), ::tolower);
- if((*it).is_cpu() && s.find("amd")!=std::string::npos){
+ if((*it).is_cpu() && s.find("amd")!=std::string::npos){ // remove amd cpu as it is not supported by computecpp
it=devices.erase(it);
}
else{
@@ -72,9 +72,9 @@ struct QueueInterface {
mutable std::map<const uint8_t *, cl::sycl::buffer<uint8_t, 1>> buffer_map;
/// sycl queue
mutable cl::sycl::queue m_queue;
- /// creating device by using cl::sycl::selector or cl::sycl::device both are the same and can be captured throufh dev_Selector typename
+ /// creating device by using cl::sycl::selector or cl::sycl::device both are the same and can be captured through dev_Selector typename
/// SyclStreamDevice is not owned. it is the caller's responsibility to destroy it.
- template<typename dev_Selector> explicit QueueInterface(dev_Selector s):
+ template<typename dev_Selector> explicit QueueInterface(const dev_Selector& s):
#ifdef EIGEN_EXCEPTIONS
m_queue(cl::sycl::queue(s, [&](cl::sycl::exception_list l) {
for (const auto& e : l) {
@@ -103,17 +103,21 @@ struct QueueInterface {
auto ptr =buf.get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::host_buffer>().get_pointer();
buf.set_final_data(nullptr);
std::lock_guard<std::mutex> lock(mutex_);
- buffer_map.insert(std::pair<const uint8_t *, cl::sycl::buffer<uint8_t, 1>>(ptr,buf));
+ buffer_map.insert(std::pair<const uint8_t *, cl::sycl::buffer<uint8_t, 1>>(static_cast<const uint8_t*>(ptr),buf));
return static_cast<void*>(ptr);
}
/// This is used to deallocate the device pointer. p is used as a key inside
/// the map to find the device buffer and delete it.
- EIGEN_STRONG_INLINE void deallocate(const void *p) const {
+ EIGEN_STRONG_INLINE void deallocate(void *p) const {
std::lock_guard<std::mutex> lock(mutex_);
auto it = buffer_map.find(static_cast<const uint8_t*>(p));
if (it != buffer_map.end()) {
+ auto num_bytes =it->second.get_size();
buffer_map.erase(it);
+ // Temporary solution for memory leak in computecpp. It will be fixed in the next computecpp version
+ std::allocator<uint8_t> a1; // Default allocator for buffer<uint8_t,1>
+ a1.deallocate(static_cast<uint8_t*>(p), num_bytes);
}
}
@@ -188,7 +192,7 @@ struct SyclDevice {
return m_queue_stream->allocate(num_bytes);
}
/// deallocate device memory
- EIGEN_STRONG_INLINE void deallocate(const void *p) const {
+ EIGEN_STRONG_INLINE void deallocate(void *p) const {
m_queue_stream->deallocate(p);
}
@@ -235,25 +239,25 @@ struct SyclDevice {
size_t rng, GRange, tileSize;
parallel_for_setup(n/sizeof(T), tileSize, rng, GRange);
// Assuming that the dst is the start of the destination pointer
- auto dest_buf = cl::sycl::buffer<uint8_t, 1, cl::sycl::map_allocator<uint8_t> >(static_cast<uint8_t*>(dst), cl::sycl::range<1>(rng*sizeof(T)));
+ auto dest_buf = cl::sycl::buffer<uint8_t, 1, cl::sycl::map_allocator<uint8_t> >(static_cast<uint8_t*>(dst), cl::sycl::range<1>(n));
sycl_queue().submit([&](cl::sycl::handler &cgh) {
auto src_acc= it->second.template get_access<cl::sycl::access::mode::read, cl::sycl::access::target::global_buffer>(cgh);
auto dst_acc =dest_buf.template get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer>(cgh);
- cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, offset));
+ cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, 0));
});
synchronize();
}
/// returning the sycl queue
EIGEN_STRONG_INLINE cl::sycl::queue& sycl_queue() const { return m_queue_stream->m_queue;}
/// Here is the implementation of memset function on sycl.
- template<typename T> EIGEN_STRONG_INLINE void memset(T *buff, int c, size_t n) const {
+ template<typename T> EIGEN_STRONG_INLINE void memset(T *data, int c, size_t n) const {
size_t rng, GRange, tileSize;
parallel_for_setup(n/sizeof(T), tileSize, rng, GRange);
sycl_queue().submit([&](cl::sycl::handler &cgh) {
- auto buf_acc =get_sycl_buffer(static_cast<uint8_t*>(static_cast<void*>(buff))). template get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer>(cgh);
+ auto buf_acc =get_sycl_buffer(static_cast<uint8_t*>(static_cast<void*>(data))). template get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer>(cgh);
cgh.parallel_for<SyclDevice>( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), [=](cl::sycl::nd_item<1> itemID) {
auto globalid=itemID.get_global_linear_id();
- if (globalid< n) {
+ if (globalid< rng) {
for(size_t i=0; i<sizeof(T); i++)
buf_acc[globalid*sizeof(T) + i] = c;
}
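
For context on the allocate/deallocate changes above: the pointer returned by allocate is the buffer's host-accessor pointer, used afterwards only as a key into buffer_map, and since ComputeCpp currently leaks that host allocation when the buffer is destroyed, deallocate frees it by hand. A condensed sketch of that bookkeeping (a simplification of the real QueueInterface; locking and error handling omitted, the struct name is hypothetical):

    #include <CL/sycl.hpp>
    #include <cstddef>
    #include <cstdint>
    #include <map>
    #include <memory>

    struct QueueInterfaceSketch {
      mutable std::map<const uint8_t*, cl::sycl::buffer<uint8_t, 1>> buffer_map;

      void* allocate(size_t num_bytes) const {
        cl::sycl::buffer<uint8_t, 1> buf(cl::sycl::range<1>(num_bytes));
        auto ptr = buf.get_access<cl::sycl::access::mode::discard_write,
                                  cl::sycl::access::target::host_buffer>().get_pointer();
        buf.set_final_data(nullptr);                        // the host storage is managed manually
        buffer_map.insert(std::make_pair(static_cast<const uint8_t*>(ptr), buf));
        return static_cast<void*>(ptr);                     // the "device pointer" is really a map key
      }

      void deallocate(void* p) const {
        auto it = buffer_map.find(static_cast<const uint8_t*>(p));
        if (it != buffer_map.end()) {
          const size_t num_bytes = it->second.get_size();
          buffer_map.erase(it);
          // temporary workaround for the ComputeCpp leak described in the patch
          std::allocator<uint8_t>().deallocate(static_cast<uint8_t*>(p), num_bytes);
        }
      }
    };
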
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 834ce07df..a68010c55 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -32,6 +32,7 @@ struct TensorEvaluator
typedef typename Derived::Scalar CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
typedef typename Derived::Dimensions Dimensions;
+ typedef Derived XprType;
// NumDimensions is -1 for variable dim tensors
static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
@@ -152,6 +153,8 @@ struct TensorEvaluator<const Derived, Device>
typedef typename Derived::Scalar CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
typedef typename Derived::Dimensions Dimensions;
+ typedef const Derived XprType;
+
// NumDimensions is -1 for variable dim tensors
static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h
index e9cef0eae..d7cbb420f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h
@@ -80,5 +80,7 @@ template<typename T> struct GetType<false, T>{
//sycl functors
#include "TensorSyclFunctors.h"
+#include "TensorContractionSycl.h"
+
#endif // end of EIGEN_USE_SYCL
#endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_H
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h
index e940c8a9d..113dd2557 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h
@@ -135,21 +135,6 @@ KERNELBROKERCONVERTERSLICESTRIDEOP(const)
KERNELBROKERCONVERTERSLICESTRIDEOP()
#undef KERNELBROKERCONVERTERSLICESTRIDEOP
-#define KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(OPEXPR, CVQual)\
-template<typename Param, typename XprType>\
-struct ConvertToDeviceExpression<CVQual OPEXPR <Param, XprType> >{\
- typedef CVQual OPEXPR<Param, typename ConvertToDeviceExpression<XprType>::Type> Type;\
-};
-
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorPaddingOp, const)
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorPaddingOp, )
-
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorReshapingOp, const)
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorReshapingOp, )
-
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorShufflingOp, const)
-KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP(TensorShufflingOp, )
-#undef KERNELBROKERCONVERTPADDINGANDRESHAPEANDSHUFFLEOP
} // namespace internal
} // namespace TensorSycl
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h
index dc8356cf4..876fcd45e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h
@@ -223,26 +223,6 @@ SYCLSLICESTRIDEOPEXTACC()
#undef SYCLSLICESTRIDEOPEXTACC
-#define PADDINGRESHAPEANDSHUFFOPEXTRACC(OPEXPR, CVQual)\
-template<typename Param, typename XprType, typename Dev>\
-struct ExtractAccessor<TensorEvaluator<CVQual OPEXPR<Param, XprType>, Dev> > {\
- static inline auto getTuple(cl::sycl::handler& cgh, const TensorEvaluator<CVQual OPEXPR<Param, XprType>, Dev>& eval)\
- -> decltype(AccessorConstructor::getTuple(cgh, eval.impl())){\
- return AccessorConstructor::getTuple(cgh, eval.impl());\
- }\
-};
-
-// tensor padding
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorPaddingOp, const)
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorPaddingOp, )
-// tensor reshaping
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorReshapingOp, const)
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorReshapingOp, )
-/// Tensor shuffling
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorShufflingOp, const)
-PADDINGRESHAPEANDSHUFFOPEXTRACC(TensorShufflingOp, )
-#undef PADDINGRESHAPEANDSHUFFOPEXTRACC
-
/// template deduction for \ref ExtractAccessor
template <typename Evaluator>
auto createTupleOfAccessors(cl::sycl::handler& cgh, const Evaluator& eval)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
index ff8be5444..4376a0e3c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h
@@ -58,7 +58,7 @@ SYCLEXTRTENSORMAPFIXEDSIZE()
template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>\
struct FunctorExtractor<TensorEvaluator<CVQual UnaryCategory<OP, RHSExpr>, Dev> > {\
FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;\
- OP func;\
+ const OP func;\
FunctorExtractor(const TensorEvaluator<CVQual UnaryCategory<OP, RHSExpr>, Dev>& expr)\
: rhsExpr(expr.impl()), func(expr.functor()) {}\
};
@@ -74,7 +74,7 @@ template <template<class, class, class> class BinaryCategory, typename OP, typen
struct FunctorExtractor<TensorEvaluator<CVQual BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {\
FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;\
FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;\
- OP func;\
+ const OP func;\
FunctorExtractor(const TensorEvaluator<CVQual BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)\
: lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}\
};
@@ -90,7 +90,7 @@ struct FunctorExtractor<TensorEvaluator<CVQual TernaryCategory<OP, Arg1Expr, Arg
FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;\
FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;\
FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;\
- OP func;\
+ const OP func;\
FunctorExtractor(const TensorEvaluator<CVQual TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)\
: arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}\
};
@@ -241,6 +241,23 @@ PADDINGOPFUNCEXT(TensorPaddingOp, padding(), padding_value(), const)
PADDINGOPFUNCEXT(TensorPaddingOp, padding(), padding_value(), )
#undef PADDINGOPFUNCEXT
+/// specialisation of the \ref FunctorExtractor struct when the node type is
+/// TensorContractionOp. The LHS and RHS here are the original ones; there is no need to apply a condition on their type.
+#define SYCLEXTRFUNCCONTRACT(CVQual)\
+template <typename Indices, typename LHSExpr, typename RHSExpr, typename Dev>\
+struct FunctorExtractor<TensorEvaluator<CVQual TensorContractionOp<Indices, LHSExpr, RHSExpr>, Dev> > {\
+ FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;\
+ FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;\
+ const Indices func;\
+ FunctorExtractor(const TensorEvaluator<CVQual TensorContractionOp<Indices, LHSExpr, RHSExpr>, Dev>& expr)\
+ : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.indices()) {}\
+};
+
+SYCLEXTRFUNCCONTRACT(const)
+SYCLEXTRFUNCCONTRACT()
+#undef SYCLEXTRFUNCCONTRACT
+
+
/// template deduction function for FunctorExtractor
template <typename Evaluator>
auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
index 5d392218e..37fe196ea 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h
@@ -132,19 +132,6 @@ SLICESTRIDEOPLEAFCOUNT(const)
SLICESTRIDEOPLEAFCOUNT()
#undef SLICESTRIDEOPLEAFCOUNT
-#define PADDINGRESHAPEANDSHUFFLELEAFCOUNT(OPEXPR, CVQual)\
-template<typename Param, typename XprType>\
-struct LeafCount<CVQual OPEXPR<Param, XprType> >:CategoryCount<XprType>{};
-
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorPaddingOp, const)
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorPaddingOp, )
-
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorReshapingOp, const)
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorReshapingOp, )
-
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorShufflingOp, const)
-PADDINGRESHAPEANDSHUFFLELEAFCOUNT(TensorShufflingOp, )
-#undef PADDINGRESHAPEANDSHUFFLELEAFCOUNT
} /// namespace TensorSycl
} /// namespace internal
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
index e1dbd0c6c..4419a1780 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h
@@ -191,21 +191,6 @@ SYCLSLICESTRIDEOPPLH(const)
SYCLSLICESTRIDEOPPLH()
#undef SYCLSLICESTRIDEOPPLH
-#define PADDINGRESHAPEANDSHUFFLEOPPLH(OPEXP , CVQual)\
-template<typename Param, typename XprType, size_t N>\
-struct PlaceHolderExpression<CVQual OPEXP<Param, XprType>, N > {\
- typedef CVQual OPEXP<Param, typename CalculateIndex<N, XprType>::ArgType> Type;\
-};
-
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorPaddingOp, const)
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorPaddingOp,)
-
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorReshapingOp, const)
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorReshapingOp, )
-
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorShufflingOp, const)
-PADDINGRESHAPEANDSHUFFLEOPPLH(TensorShufflingOp,)
-#undef PADDINGRESHAPEANDSHUFFLEOPPLH
/// template deduction for \ref PlaceHolderExpression struct
template <typename Expr>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h
index f259f03c4..69f7211cf 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h
@@ -56,7 +56,6 @@ void run(Expr &expr, Dev &dev) {
});
dev.synchronize();
}
-
evaluator.cleanup();
}
} // namespace TensorSycl