diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index e87de0c57..5b4c3c5bd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -11,7 +11,7 @@ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. /***************************************************************** - * TensorSyclConvertToDeviceExpression.h + * TensorTensorContractionsycl.h * * \brief: * TensorContractionsycl @@ -84,7 +84,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { this->m_leftImpl.evalSubExprsIfNeeded(NULL); this->m_rightImpl.evalSubExprsIfNeeded(NULL); - if (data) { + if (data) { evalTo(data); return false; } else { @@ -173,6 +173,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS LhsLocalAcc localLhs; RhsLocalAcc localRhs; OutAccessor out_res; + size_t out_offset; Index roundUpK, M, N, K; ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides; LeftNocontractT m_i_strides, m_left_nocontract_strides; @@ -182,11 +183,12 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS Device dev; - KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, + KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_, Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_, ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_, LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_) - :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), + :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), + out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_), m_right_contracting_strides(m_right_contracting_strides_), m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_), @@ -230,13 +232,13 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS const Index nGroupId = itemID.get_group(1); // Work-group ID localCol const Index linearLocalThreadId = nLocalThreadId*LocalThreadSizeM + mLocalThreadId; // linear local thread ID // Allocate register space - float privateLhs; - float privateRhs[WorkLoadPerThreadN]; - float privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN]; + LhsScalar privateLhs; + RhsScalar privateRhs[WorkLoadPerThreadN]; + OutScalar privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN]; // Initialise the privateResumulation registers for (Index wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) { for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) { - privateRes[wLPTM][wLPTN] = 0.0f; + privateRes[wLPTM][wLPTN] = static_cast<OutScalar>(0); } } @@ -316,7 +318,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) { Index globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN; if(globalCol<N) - out_ptr[globalCol*M + globalRow] = privateRes[wLPTM][wLPTN]; + out_ptr[globalCol*M + globalRow +ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN]; } } } @@ -356,12 +358,12 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo // extract lhs functor list LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); // extract rhs functor list - RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl()); Index roundUpK = RoundUp(K, TileSizeDimK); Index roundUpM = RoundUp(M, TileSizeDimM); Index roundUpN = RoundUp(N, TileSizeDimN); - + ptrdiff_t out_offset = self.device().get_offset(buffer); self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) { /// work-around for gcc bug typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType; @@ -379,18 +381,17 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc; RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh); - typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write, cl::sycl::access::target::global_buffer> OutAccessor; + typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor; //OutScalar memory - OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer); - + OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer); // sycl parallel for cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN), cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)), KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT, RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK, - WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors, - localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, - m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice())); + WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::SyclKernelDevice>(lhs_functors, rhs_functors, + localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, + m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::SyclKernelDevice())); }); self.device().asynchronousExec(); } |