author	Mehdi Goli <mehdi.goli@codeplay.com>	2017-02-01 15:29:53 +0000
committer	Mehdi Goli <mehdi.goli@codeplay.com>	2017-02-01 15:29:53 +0000
commit	bab29936a1cf0a68ffe4ccb1fd9b4807a3ec87ae (patch)
tree	c750b36227a31ddb2a1e0d5fd11f0036fda775db /unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
parent	48a20b7d956433713a39e04d39cba443b7a763de (diff)
Reducing warnings in Sycl backend.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r--	unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h	115
1 file changed, 57 insertions(+), 58 deletions(-)
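The warnings this commit targets come from mixing 32-bit int loop counters and tile-size constants with the evaluator's wider Index type; the patch below templates LaunchSyclKernels and the kernel constructor on Index and uses it for every counter and constant. A minimal standalone sketch of that pattern follows (names and types here are illustrative, not taken from the patch):

#include <cstddef>

// Illustrative sketch only: a 32-bit int counter compared against a wider
// (often unsigned) index type triggers -Wsign-compare / conversion warnings.
void scale_warns(float* data, std::size_t n) {
  for (int i = 0; i < n; ++i)   // warning: comparison between int and std::size_t
    data[i] *= 2.0f;
}

// Using the same index type for the counter, as the patch does with Index,
// silences the warning and stays correct for sizes beyond INT_MAX.
template <typename Index>
void scale_clean(float* data, Index n) {
  for (Index i = 0; i < n; ++i)
    data[i] *= 2.0f;
}

In the diff the same idea is applied to the tile-size constants (32ul, 16ul, ...) and to RoundUp, so all index arithmetic inside the kernel is carried out in Index.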
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
index dc16f89e0..e87de0c57 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
@@ -22,7 +22,7 @@
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
namespace Eigen {
-template <typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
+template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> :
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > {
@@ -146,7 +146,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
- LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
+ LaunchSyclKernels<Index, LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
}
@@ -162,8 +162,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename LHSFunctorExpr, typename RHSFunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
-int TileSizeDimM, int TileSizeDimN,int TileSizeDimK, int WorkLoadPerThreadM,int WorkLoadPerThreadN,
-int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{
+typename HostExpr::Index TileSizeDimM, typename HostExpr::Index TileSizeDimN,typename HostExpr::Index TileSizeDimK, typename HostExpr::Index WorkLoadPerThreadM,typename HostExpr::Index WorkLoadPerThreadN,
+typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadSizeN, typename HostExpr::Index LoadPerThreadLhs, typename HostExpr::Index LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{
typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr;
@@ -224,84 +224,83 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr
auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res);
// Matmul Kernel
// Thread identifiers
- const int mLocalThreadId = itemID.get_local(0); // Local ID row
- const int nLocalThreadId = itemID.get_local(1); // Local ID col
- const int mGroupId = itemID.get_group(0); // Work-group ID row
- const int nGroupId = itemID.get_group(1); // Work-group ID localCol
- const int linearLocalThreadId = nLocalThreadId*LocalThreadSizeM + mLocalThreadId; // linear local thread ID
+ const Index mLocalThreadId = itemID.get_local(0); // Local ID row
+ const Index nLocalThreadId = itemID.get_local(1); // Local ID col
+ const Index mGroupId = itemID.get_group(0); // Work-group ID row
+ const Index nGroupId = itemID.get_group(1); // Work-group ID localCol
+ const Index linearLocalThreadId = nLocalThreadId*LocalThreadSizeM + mLocalThreadId; // linear local thread ID
// Allocate register space
float privateLhs;
float privateRhs[WorkLoadPerThreadN];
float privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN];
// Initialise the privateResumulation registers
- for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
- for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
privateRes[wLPTM][wLPTN] = 0.0f;
}
}
// Tile Lhs
- for (int lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
- int
- localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
- int localLhsRow = localLhsLinearId% TileSizeDimM;
- int localLhsCol = localLhsLinearId/TileSizeDimM;
+ for (Index lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
+ Index localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ Index localLhsRow = localLhsLinearId% TileSizeDimM;
+ Index localLhsCol = localLhsLinearId/TileSizeDimM;
// Load the value (wide vector load)
- int GlobalLhsColId = TileSizeDimK*0 + localLhsCol;
+ Index GlobalLhsColId = TileSizeDimK*0 + localLhsCol;
localLhs[0 + ((localLhsCol*TileSizeDimM + localLhsRow)*2)] =((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId):static_cast<OutScalar>(0);
}
// Tile Rhs
- for (int lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
- int localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
- int localRhsRow = localRhsLinearId% TileSizeDimN;
- int localRhsCol = localRhsLinearId/TileSizeDimN;
+ for (Index lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
+ Index localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ Index localRhsRow = localRhsLinearId% TileSizeDimN;
+ Index localRhsCol = localRhsLinearId/TileSizeDimN;
// Load the value (wide vector load)
- int GlobalRhsRowId = TileSizeDimK*0 + localRhsCol;
+ Index GlobalRhsRowId = TileSizeDimK*0 + localRhsCol;
localRhs[0 + ((localRhsCol*TileSizeDimN + localRhsRow) *2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow): static_cast<OutScalar>(0);
}
// Loop over all tiles
- const int numTiles = roundUpK/TileSizeDimK;
- int firstHalf=0;
+ const Index numTiles = roundUpK/TileSizeDimK;
+ Index firstHalf=0;
do {
// Synchronise
itemID.barrier(cl::sycl::access::fence_space::local_space);
// Load the next tile of Lhs and Rhs into local memory
- int nextHalf = firstHalf + 1;
+ Index nextHalf = firstHalf + 1;
if (nextHalf < numTiles) {
// Tile A
- for (int lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
- int localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
- int localLhsRow = localLhsLinearId% TileSizeDimM;
- int localLhsCol = localLhsLinearId/TileSizeDimM;
+ for (Index lPTL=0; lPTL<LoadPerThreadLhs; lPTL++) {
+ Index localLhsLinearId = lPTL*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ Index localLhsRow = localLhsLinearId% TileSizeDimM;
+ Index localLhsCol = localLhsLinearId/TileSizeDimM;
// global K id
- int GlobalLhsColId = TileSizeDimK*nextHalf + localLhsCol;
+ Index GlobalLhsColId = TileSizeDimK*nextHalf + localLhsCol;
// Store the loaded value into local memory
localLhs[(nextHalf%2) + ((localLhsCol*TileSizeDimM + localLhsRow) *2)] = ((GlobalLhsColId < K)&& (mGroupId*(TileSizeDimM)+ localLhsRow <M))? lhs(mGroupId*(TileSizeDimM) + localLhsRow, GlobalLhsColId): static_cast<OutScalar>(0);
}
// Tile B
- for (int lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
- int localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
- int localRhsRow = localRhsLinearId% TileSizeDimN;
- int localRhsCol = localRhsLinearId/TileSizeDimN;
+ for (Index lPTR=0; lPTR<LoadPerThreadRhs; lPTR++) {
+ Index localRhsLinearId = lPTR*LocalThreadSizeN*LocalThreadSizeM + linearLocalThreadId;
+ Index localRhsRow = localRhsLinearId% TileSizeDimN;
+ Index localRhsCol = localRhsLinearId/TileSizeDimN;
// Load the value (wide vector load)
- int GlobalRhsRowId = TileSizeDimK*nextHalf + localRhsCol;
+ Index GlobalRhsRowId = TileSizeDimK*nextHalf + localRhsCol;
// Store the loaded vector into local memory
localRhs[(nextHalf%2) +((localRhsCol*TileSizeDimN + localRhsRow)*2)] = ((GlobalRhsRowId < K)&& ((nGroupId*(TileSizeDimN) + localRhsRow)< N))? rhs(GlobalRhsRowId, nGroupId*(TileSizeDimN) + localRhsRow):static_cast<OutScalar>(0);
}
}
// Loop over the values of a single tile
- for (int k=0; k<TileSizeDimK; k++) {
+ for (Index k=0; k<TileSizeDimK; k++) {
// Cache the values of localRhs in registers
- for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
- int localRhsCol = nLocalThreadId + wLPTN*LocalThreadSizeN;
+ for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ Index localRhsCol = nLocalThreadId + wLPTN*LocalThreadSizeN;
privateRhs[wLPTN] = localRhs[(firstHalf%2) +((k*TileSizeDimN + localRhsCol)*2)];
}
// Perform the computation
- for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
- int localLhsRow = mLocalThreadId + wLPTM*LocalThreadSizeM;
+ for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ Index localLhsRow = mLocalThreadId + wLPTM*LocalThreadSizeM;
privateLhs = localLhs[(firstHalf%2)+ ((k*TileSizeDimM + localLhsRow)*2)];
- for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
privateRes[wLPTM][wLPTN] += privateLhs * privateRhs[wLPTN];
}
}
@@ -311,11 +310,11 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr
} while (firstHalf<numTiles);
// Store the final results in C
- for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
- int globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM;
+ for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
+ Index globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM;
if (globalRow< M){
- for (int wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
- int globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN;
+ for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
+ Index globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN;
if(globalCol<N)
out_ptr[globalCol*M + globalRow] = privateRes[wLPTM][wLPTN];
}
@@ -325,24 +324,24 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr
}
};
-template <typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels {
-
-static const int TileSizeDimM = 32; // Tile size for dimension M
-static const int TileSizeDimN = 32; // Tile size for dimension N
-static const int TileSizeDimK = 16; // Tile size for dimension K
-static const int WorkLoadPerThreadM = 4; // Work load per thread in dimension M
-static const int WorkLoadPerThreadN = 4; // work load per thread in dimension N
-static const int LocalThreadSizeM = (TileSizeDimM/WorkLoadPerThreadM); // Local thread size for the first dimension (M here)
-static const int LocalThreadSizeN = (TileSizeDimN/WorkLoadPerThreadN); // Local thread size for the second dimension (N here)
-static const int LoadPerThreadLhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimN)); // workload per thread for Lhs expression
-static const int LoadPerThreadRhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimM)); // workload per thread for Rhs expression
+template <typename Index, typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels {
+
+static const Index TileSizeDimM = 32ul; // Tile size for dimension M
+static const Index TileSizeDimN = 32ul; // Tile size for dimension N
+static const Index TileSizeDimK = 16ul; // Tile size for dimension K
+static const Index WorkLoadPerThreadM = 4ul; // Work load per thread in dimension M
+static const Index WorkLoadPerThreadN = 4ul; // work load per thread in dimension N
+static const Index LocalThreadSizeM = (TileSizeDimM/WorkLoadPerThreadM); // Local thread size for the first dimension (M here)
+static const Index LocalThreadSizeN = (TileSizeDimN/WorkLoadPerThreadN); // Local thread size for the second dimension (N here)
+static const Index LoadPerThreadLhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimN)); // workload per thread for Lhs expression
+static const Index LoadPerThreadRhs = ((TileSizeDimK*WorkLoadPerThreadM*WorkLoadPerThreadN)/(TileSizeDimM)); // workload per thread for Rhs expression
// RoundUp function to make sure that the global threadId is divisable by local threadId
-static int RoundUp(int x, int y) {
+static Index RoundUp(Index x, Index y) {
return ((((x) + (y) - 1) / (y))*(y));
}
-template< typename Self, typename OutScalar, typename Index, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
+template< typename Self, typename OutScalar, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K,
ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides){