// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorContractionSycl.h
 *
 * \brief:
 *  TensorContractionSycl
 *
 *****************************************************************/

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H

namespace Eigen {

template <typename Index, typename LhsScalar, typename RhsScalar,
          bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
          bool rhs_inner_dim_reordered> struct LaunchSyclKernels;

template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > {

  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
                "SYCL tensor contraction does not support output kernels.");

  typedef const Eigen::SyclDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;
  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : Base(op, device) {}

  // We need to redefine this method to make nvcc happy
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  const Eigen::SyclDevice& device() const { return this->m_device; }
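  // A minimal host-side usage sketch of the path this evaluator serves. This
  // is illustrative only (kept as a comment): the selector/device setup and
  // buffer handling follow the pattern of Eigen's SYCL tests, and all
  // variable names here are hypothetical.
  //
  //   cl::sycl::gpu_selector selector;
  //   Eigen::QueueInterface queueInterface(selector);
  //   Eigen::SyclDevice syclDevice(&queueInterface);
  //
  //   Eigen::Tensor<float, 2> a(m, k), b(k, n), c(m, n);
  //   a.setRandom(); b.setRandom();
  //   float* gpu_a = static_cast<float*>(syclDevice.allocate(a.size() * sizeof(float)));
  //   float* gpu_b = static_cast<float*>(syclDevice.allocate(b.size() * sizeof(float)));
  //   float* gpu_c = static_cast<float*>(syclDevice.allocate(c.size() * sizeof(float)));
  //   syclDevice.memcpyHostToDevice(gpu_a, a.data(), a.size() * sizeof(float));
  //   syclDevice.memcpyHostToDevice(gpu_b, b.data(), b.size() * sizeof(float));
  //
  //   Eigen::TensorMap<Eigen::Tensor<float, 2> > A(gpu_a, m, k), B(gpu_b, k, n), C(gpu_c, m, n);
  //   Eigen::array<Eigen::IndexPair<long>, 1> dims = {{Eigen::IndexPair<long>(1, 0)}};
  //   C.device(syclDevice) = A.contract(B, dims);  // ends up in evalTo/LaunchSyclKernels below
  //   syclDevice.memcpyDeviceToHost(c.data(), gpu_c, c.size() * sizeof(float));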
  void evalTo(Scalar* buffer) const {
    // Here is the result
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    LaunchSyclKernels<Index, LhsScalar, RhsScalar, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(
        *this, buffer, m, n, k,
        this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
        this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
  }

  // required by sycl to construct the expr on the device. Returns the original left_impl
  const TensorEvaluator<LeftArgType, Device>& left_impl() const {
    return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_leftImpl, this->m_rightImpl);
  }
  // required by sycl to construct the expr on the device. Returns the original right_impl
  const TensorEvaluator<RightArgType, Device>& right_impl() const {
    return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_rightImpl, this->m_leftImpl);
  }
};

template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar,
          typename LHSFunctorExpr, typename RHSFunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc,
          typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
          typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
          bool rhs_inner_dim_reordered,
          typename HostExpr::Index TileSizeDimM, typename HostExpr::Index TileSizeDimN,
          typename HostExpr::Index TileSizeDimK, typename HostExpr::Index WorkLoadPerThreadM,
          typename HostExpr::Index WorkLoadPerThreadN, typename HostExpr::Index LocalThreadSizeM,
          typename HostExpr::Index LocalThreadSizeN, typename HostExpr::Index LoadPerThreadLhs,
          typename HostExpr::Index LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device>
struct KernelConstructor {
  typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
  typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<RHSHostExpr>::Type RHSPlaceHolderExpr;
  LHSFunctorExpr lhs_functors;
  RHSFunctorExpr rhs_functors;
  LhsLocalAcc localLhs;
  RhsLocalAcc localRhs;
  OutAccessor out_res;
  size_t out_offset;
  Index roundUpK, M, N, K;
  ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides;
  LeftNocontractT m_i_strides, m_left_nocontract_strides;
  RightNocontractT m_j_strides, m_right_nocontract_strides;
  LHSTupleType left_tuple_of_accessors;
  RHSTupleType right_tuple_of_accessors;
  Device dev;

  KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_,
                    RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_, Index roundUpK_,
                    Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_,
                    ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_,
                    LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_,
                    LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_)
      : lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_),
        out_res(out_res_), out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
        m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_),
        m_right_contracting_strides(m_right_contracting_strides_),
        m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_),
        m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_),
        left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_) {}
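  // The functor below implements a classic work-group-tiled GEMM:
  //  - each work-group computes a TileSizeDimM x TileSizeDimN block of the output;
  //  - each work-item accumulates a WorkLoadPerThreadM x WorkLoadPerThreadN
  //    sub-block of that tile in registers (privateRes);
  //  - tiles of Lhs and Rhs along the contraction dimension K are staged through
  //    local memory and double-buffered (hence the interleaved "*2"/"%2" indexing),
  //    so the next K-tile is loaded while the current one is consumed.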
  void operator()(cl::sycl::nd_item<2> itemID) {
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<LHSHostExpr>::Type LHSDevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<RHSHostExpr>::Type RHSDevExpr;
    auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<LHSDevExpr, LHSPlaceHolderExpr>(lhs_functors, left_tuple_of_accessors);
    auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<RHSDevExpr, RHSPlaceHolderExpr>(rhs_functors, right_tuple_of_accessors);
    typedef decltype(lhs_dev_expr.expr) LeftArgType;
    typedef decltype(rhs_dev_expr.expr) RightArgType;
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, LeftNocontractT,
                                                   ContractT, 1,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned, MakeGlobalPointer> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, RightNocontractT,
                                                   ContractT, 1,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper;
    // initializing the data mappers must happen inside the kernel for device eval
    LhsMapper lhs(LeftEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
                                       lhs_dev_expr.expr, rhs_dev_expr.expr), dev),
                  m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(RightEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
                                        rhs_dev_expr.expr, lhs_dev_expr.expr), dev),
                  m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
    auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res);

    // Matmul Kernel
    // Thread identifiers
    const Index mLocalThreadId = itemID.get_local(0); // Local ID row
    const Index nLocalThreadId = itemID.get_local(1); // Local ID col
    const Index mGroupId = itemID.get_group(0);       // Work-group ID row
    const Index nGroupId = itemID.get_group(1);       // Work-group ID col
    const Index linearLocalThreadId = nLocalThreadId * LocalThreadSizeM + mLocalThreadId; // linear local thread ID
    // Allocate register space
    LhsScalar privateLhs;
    RhsScalar privateRhs[WorkLoadPerThreadN];
    OutScalar privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN];
    // Initialise the privateRes accumulation registers
    for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
      for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
        privateRes[wLPTM][wLPTN] = static_cast<OutScalar>(0);
      }
    }

    // Tile Lhs: load the first K-tile into half 0 of the local double buffer
    for (Index lPTL = 0; lPTL < LoadPerThreadLhs; lPTL++) {
      Index localLhsLinearId = lPTL * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
      Index localLhsRow = localLhsLinearId % TileSizeDimM;
      Index localLhsCol = localLhsLinearId / TileSizeDimM;
      // Load the value (one element per work-item and iteration)
      Index GlobalLhsColId = TileSizeDimK * 0 + localLhsCol;
      localLhs[0 + ((localLhsCol * TileSizeDimM + localLhsRow) * 2)] =
          ((GlobalLhsColId < K) && (mGroupId * TileSizeDimM + localLhsRow < M))
              ? lhs(mGroupId * TileSizeDimM + localLhsRow, GlobalLhsColId)
              : static_cast<OutScalar>(0);
    }
    // Tile Rhs: load the first K-tile into half 0 of the local double buffer
    for (Index lPTR = 0; lPTR < LoadPerThreadRhs; lPTR++) {
      Index localRhsLinearId = lPTR * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
      Index localRhsRow = localRhsLinearId % TileSizeDimN;
      Index localRhsCol = localRhsLinearId / TileSizeDimN;
      // Load the value (one element per work-item and iteration)
      Index GlobalRhsRowId = TileSizeDimK * 0 + localRhsCol;
      localRhs[0 + ((localRhsCol * TileSizeDimN + localRhsRow) * 2)] =
          ((GlobalRhsRowId < K) && (nGroupId * TileSizeDimN + localRhsRow < N))
              ? rhs(GlobalRhsRowId, nGroupId * TileSizeDimN + localRhsRow)
              : static_cast<OutScalar>(0);
    }
    // Loop over all tiles
    const Index numTiles = roundUpK / TileSizeDimK;
    Index firstHalf = 0;
    do {
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      // Load the next tile of Lhs and Rhs into local memory
      Index nextHalf = firstHalf + 1;
      if (nextHalf < numTiles) {
        // Tile A
        for (Index lPTL = 0; lPTL < LoadPerThreadLhs; lPTL++) {
          Index localLhsLinearId = lPTL * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
          Index localLhsRow = localLhsLinearId % TileSizeDimM;
          Index localLhsCol = localLhsLinearId / TileSizeDimM;
          // global K id
          Index GlobalLhsColId = TileSizeDimK * nextHalf + localLhsCol;
          // Store the loaded value into the other half of the local buffer
          localLhs[(nextHalf % 2) + ((localLhsCol * TileSizeDimM + localLhsRow) * 2)] =
              ((GlobalLhsColId < K) && (mGroupId * TileSizeDimM + localLhsRow < M))
                  ? lhs(mGroupId * TileSizeDimM + localLhsRow, GlobalLhsColId)
                  : static_cast<OutScalar>(0);
        }
        // Tile B
        for (Index lPTR = 0; lPTR < LoadPerThreadRhs; lPTR++) {
          Index localRhsLinearId = lPTR * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
          Index localRhsRow = localRhsLinearId % TileSizeDimN;
          Index localRhsCol = localRhsLinearId / TileSizeDimN;
          // global K id
          Index GlobalRhsRowId = TileSizeDimK * nextHalf + localRhsCol;
          // Store the loaded value into the other half of the local buffer
          localRhs[(nextHalf % 2) + ((localRhsCol * TileSizeDimN + localRhsRow) * 2)] =
              ((GlobalRhsRowId < K) && (nGroupId * TileSizeDimN + localRhsRow < N))
                  ? rhs(GlobalRhsRowId, nGroupId * TileSizeDimN + localRhsRow)
                  : static_cast<OutScalar>(0);
        }
      }
      // Loop over the values of a single tile
      for (Index k = 0; k < TileSizeDimK; k++) {
        // Cache the values of localRhs in registers
        for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
          Index localRhsCol = nLocalThreadId + wLPTN * LocalThreadSizeN;
          privateRhs[wLPTN] = localRhs[(firstHalf % 2) + ((k * TileSizeDimN + localRhsCol) * 2)];
        }
        // Perform the computation
        for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
          Index localLhsRow = mLocalThreadId + wLPTM * LocalThreadSizeM;
          privateLhs = localLhs[(firstHalf % 2) + ((k * TileSizeDimM + localLhsRow) * 2)];
          for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
            privateRes[wLPTM][wLPTN] += privateLhs * privateRhs[wLPTN];
          }
        }
      }
      // Next tile
      firstHalf++;
    } while (firstHalf < numTiles);

    // Store the final results in the output
    for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
      Index globalRow = mGroupId * TileSizeDimM + mLocalThreadId + wLPTM * LocalThreadSizeM;
      if (globalRow < M) {
        for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
          Index globalCol = nGroupId * TileSizeDimN + nLocalThreadId + wLPTN * LocalThreadSizeN;
          if (globalCol < N)
            out_ptr[globalCol * M + globalRow + ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN];
        }
      }
    }
  }
};
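// The default tile configuration below, worked through once (illustrative
// arithmetic only, derived from the constants that follow):
//  - threads per work-group: LocalThreadSizeM * LocalThreadSizeN = (32/4) * (32/4) = 64;
//  - LoadPerThreadLhs = (16*4*4)/32 = 8, so one load phase moves 64 * 8 = 512
//    = TileSizeDimM * TileSizeDimK elements, i.e. exactly one half of the
//    double-buffered 2 * TileSizeDimM * TileSizeDimK local-memory array;
//  - e.g. for M = N = 100, K = 50: RoundUp yields 128, 128 and 64, so the
//    nd_range is global (128/4, 128/4) = (32, 32) over local (8, 8):
//    a 4 x 4 grid of work-groups, each covering a 32 x 32 output tile.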
template <typename Index, typename LhsScalar, typename RhsScalar,
          bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
struct LaunchSyclKernels {

  static const Index TileSizeDimM = 32ul;       // Tile size for dimension M
  static const Index TileSizeDimN = 32ul;       // Tile size for dimension N
  static const Index TileSizeDimK = 16ul;       // Tile size for dimension K
  static const Index WorkLoadPerThreadM = 4ul;  // Work load per thread in dimension M
  static const Index WorkLoadPerThreadN = 4ul;  // Work load per thread in dimension N
  static const Index LocalThreadSizeM = (TileSizeDimM / WorkLoadPerThreadM);  // Local thread size for the first dimension (M here)
  static const Index LocalThreadSizeN = (TileSizeDimN / WorkLoadPerThreadN);  // Local thread size for the second dimension (N here)
  static const Index LoadPerThreadLhs = ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN)); // workload per thread for Lhs expression
  static const Index LoadPerThreadRhs = ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM)); // workload per thread for Rhs expression

  // RoundUp function to make sure that the global thread count is divisible by the local thread count
  static Index RoundUp(Index x, Index y) {
    return ((((x) + (y) - 1) / (y)) * (y));
  }

  template <typename Self, typename OutScalar, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
  static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K,
                  ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
                  LeftNocontractT m_i_strides, RightNocontractT m_j_strides,
                  LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides) {

    typedef typename Self::XprType HostExpr;
    typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
    typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
    typedef TensorEvaluator<LHSHostExpr, const Eigen::SyclDevice> OrigLHSExpr;
    typedef TensorEvaluator<RHSHostExpr, const Eigen::SyclDevice> OrigRHSExpr;
    typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigLHSExpr> LHSFunctorExpr;
    typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigRHSExpr> RHSFunctorExpr;
    // extract lhs functor list
    LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
    // extract rhs functor list
    RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl());

    Index roundUpK = RoundUp(K, TileSizeDimK);
    Index roundUpM = RoundUp(M, TileSizeDimM);
    Index roundUpN = RoundUp(N, TileSizeDimN);
    ptrdiff_t out_offset = self.device().get_offset(buffer);
    self.device().sycl_queue().submit([&](cl::sycl::handler& cgh) {
      /// work-around for gcc bug
      typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType;
      /// work-around for gcc bug
      typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl())) RHSTupleType;
      // create lhs tuple of accessors
      LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl());
      // create rhs tuple of accessors
      RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl());

      // Local memory for elements of Lhs (double-buffered)
      typedef cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LhsLocalAcc;
      LhsLocalAcc localLhs(cl::sycl::range<1>(2 * TileSizeDimM * TileSizeDimK), cgh);
      // Local memory for elements of Rhs (double-buffered)
      typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc;
      RhsLocalAcc localRhs(cl::sycl::range<1>(2 * TileSizeDimK * TileSizeDimN), cgh);

      typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor;
      // OutScalar memory
      OutAccessor out_res = self.device().template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer);
      // sycl parallel_for
      cgh.parallel_for(
          cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM / WorkLoadPerThreadM, roundUpN / WorkLoadPerThreadN),
                                cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)),
          KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr,
                            LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
                            RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                            rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK,
                            WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN,
                            LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::SyclKernelDevice>(
              lhs_functors, rhs_functors, localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K,
              m_k_strides, m_left_contracting_strides, m_right_contracting_strides, m_i_strides, m_j_strides,
              m_left_nocontract_strides, m_right_nocontract_strides, left_tuple_of_accessors,
              right_tuple_of_accessors, Eigen::SyclKernelDevice()));
    });
    self.device().asynchronousExec();
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H