aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
diff options
context:
space:
mode:
authorGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-02-28 17:16:14 +0000
committerGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-02-28 17:16:14 +0000
commit8296b87d7bd98c19c6064241880691f164790ede (patch)
treebbd18de82debbf021c7643017f9588a16374934f /unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
parente0bd6f5738b94e8d7a4b17b61bf9cb6418685f28 (diff)
Adding sycl backend for TensorCustomOp; fixing the partial lhs modification issue on sycl when the rhs is TensorContraction, reduction or convolution; Fixing the partial modification for memset when sycl backend is used.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h21
1 files changed, 11 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
index abc7ba551..fcd7d4d00 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
@@ -84,7 +84,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
this->m_leftImpl.evalSubExprsIfNeeded(NULL);
this->m_rightImpl.evalSubExprsIfNeeded(NULL);
- if (data) {
+ if (data) {
evalTo(data);
return false;
} else {
@@ -173,6 +173,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS
LhsLocalAcc localLhs;
RhsLocalAcc localRhs;
OutAccessor out_res;
+ size_t out_offset;
Index roundUpK, M, N, K;
ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides;
LeftNocontractT m_i_strides, m_left_nocontract_strides;
@@ -182,11 +183,12 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS
Device dev;
- KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_,
+ KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_,
Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_,
ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_,
LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_)
- :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
+ :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_),
+ out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_),
m_right_contracting_strides(m_right_contracting_strides_),
m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_),
@@ -316,7 +318,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS
for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) {
Index globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN;
if(globalCol<N)
- out_ptr[globalCol*M + globalRow] = privateRes[wLPTM][wLPTN];
+ out_ptr[globalCol*M + globalRow +ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN];
}
}
}
@@ -356,12 +358,12 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo
// extract lhs functor list
LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
// extract rhs functor list
- RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
+ RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl());
Index roundUpK = RoundUp(K, TileSizeDimK);
Index roundUpM = RoundUp(M, TileSizeDimM);
Index roundUpN = RoundUp(N, TileSizeDimN);
-
+ ptrdiff_t out_offset = self.device().get_offset(buffer);
self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) {
/// work-around for gcc bug
typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType;
@@ -379,17 +381,16 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo
typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc;
RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh);
- typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write, cl::sycl::access::target::global_buffer> OutAccessor;
+ typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor;
//OutScalar memory
- OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer);
-
+ OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer);
// sycl parallel for
cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN),
cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)),
KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK,
WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors,
- localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides,
+ localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides,
m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice()));
});
self.device().asynchronousExec();