diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index abc7ba551..fcd7d4d00 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -84,7 +84,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { this->m_leftImpl.evalSubExprsIfNeeded(NULL); this->m_rightImpl.evalSubExprsIfNeeded(NULL); - if (data) { + if (data) { evalTo(data); return false; } else { @@ -173,6 +173,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS LhsLocalAcc localLhs; RhsLocalAcc localRhs; OutAccessor out_res; + size_t out_offset; Index roundUpK, M, N, K; ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides; LeftNocontractT m_i_strides, m_left_nocontract_strides; @@ -182,11 +183,12 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS Device dev; - KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, + KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_, Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_, ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_, LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_) - :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), + :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), + out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_), m_right_contracting_strides(m_right_contracting_strides_), m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_), @@ -316,7 +318,7 @@ typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadS for (Index wLPTN=0; wLPTN<WorkLoadPerThreadN; wLPTN++) { Index globalCol = nGroupId*TileSizeDimN + nLocalThreadId + wLPTN*LocalThreadSizeN; if(globalCol<N) - out_ptr[globalCol*M + globalRow] = privateRes[wLPTM][wLPTN]; + out_ptr[globalCol*M + globalRow +ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN]; } } } @@ -356,12 +358,12 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo // extract lhs functor list LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); // extract rhs functor list - RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl()); Index roundUpK = RoundUp(K, TileSizeDimK); Index roundUpM = RoundUp(M, TileSizeDimM); Index roundUpN = RoundUp(N, TileSizeDimN); - + ptrdiff_t out_offset = self.device().get_offset(buffer); self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) { /// work-around for gcc bug typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType; @@ -379,17 +381,16 @@ template< typename Self, typename OutScalar, typename ContractT, typename LeftNo typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc; RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh); - typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write, cl::sycl::access::target::global_buffer> OutAccessor; + typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor; //OutScalar memory - OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer); - + OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer); // sycl parallel for cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN), cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)), KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT, RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK, WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors, - localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, + localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice())); }); self.device().asynchronousExec(); |