diff options
author | Mehdi Goli <mehdi.goli@codeplay.com> | 2017-01-16 13:58:49 +0000 |
---|---|---|
committer | Mehdi Goli <mehdi.goli@codeplay.com> | 2017-01-16 13:58:49 +0000 |
commit | e46e7223817cfd982edec6d8e25c77e8e2493d78 (patch) | |
tree | 3b8345ae7bb7ab2434b117932aea51f016acf43d /unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | |
parent | 23778a15d8570b4287820f540b719203e07cfb44 (diff) |
Adding Tensor ReverseOp; TensorStriding; TensorConversionOp; Modifying Tensor Contractsycl to be located in any place in the expression tree.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 131 |
1 files changed, 65 insertions, 66 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index b170a1a5c..dc16f89e0 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -146,9 +146,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); - LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k, - this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides, - this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides); + LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k, + this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides, + this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides); } // required by sycl to construct the expr on the device. Returns original left_impl const TensorEvaluator<LeftArgType, Device>& left_impl() const { @@ -158,47 +158,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const TensorEvaluator<RightArgType, Device>& right_impl() const { return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_rightImpl, this->m_leftImpl); } - // required by sycl to construct the expr on the device - const Indices& indices() const {return this->m_expr_indices;} }; -/// Dummy container on the device. This is used to avoid calling the constructor of TensorEvaluator for TensorContractionOp. This makes the code much faster. -template<typename Expr> struct TensorEvaluatorContainer; -template<typename Indices, typename LeftArgType, typename RightArgType> -struct TensorEvaluatorContainer<TensorContractionOp<Indices, LeftArgType, RightArgType>>{ - typedef Eigen::DefaultDevice Device; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; - typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; - typedef typename XprType::Index Index; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename PacketType<CoeffReturnType, Eigen::DefaultDevice>::type PacketReturnType; - enum { - Layout = TensorEvaluator<LeftArgType, Device>::Layout, - }; - - typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; - typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; - typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; - - TensorEvaluatorContainer(const XprType& op, const Eigen::DefaultDevice& device) - : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), - op.lhsExpression(), op.rhsExpression()), device), - m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), - op.rhsExpression(), op.lhsExpression()), device){} -LeftEvaluator m_leftImpl; -RightEvaluator m_rightImpl; -}; - - -template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename FunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT, +template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename LHSFunctorExpr, typename RHSFunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT, typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int TileSizeDimM, int TileSizeDimN,int TileSizeDimK, int WorkLoadPerThreadM,int WorkLoadPerThreadN, -int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename TupleType> struct KernelConstructor{ - - typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<HostExpr>::Type PlaceHolderExpr; - - FunctorExpr functors; +int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{ + typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr; + typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr; + typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr; + typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<RHSHostExpr>::Type RHSPlaceHolderExpr; + LHSFunctorExpr lhs_functors; + RHSFunctorExpr rhs_functors; LhsLocalAcc localLhs; RhsLocalAcc localRhs; OutAccessor out_res; @@ -206,38 +177,50 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides; LeftNocontractT m_i_strides, m_left_nocontract_strides; RightNocontractT m_j_strides, m_right_nocontract_strides; - TupleType tuple_of_accessors; + LHSTupleType left_tuple_of_accessors; + RHSTupleType right_tuple_of_accessors; + Device dev; + - KernelConstructor(FunctorExpr functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, + KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_, ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_, - LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, TupleType tuple_of_accessors_) - :functors(functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), + LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_) + :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_), m_right_contracting_strides(m_right_contracting_strides_), m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_), m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_), - tuple_of_accessors(tuple_of_accessors_){} + left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_){} void operator()(cl::sycl::nd_item<1> itemID) { - typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr; - auto device_expr =Eigen::TensorSycl::internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors); - auto device_evaluator = TensorEvaluatorContainer<DevExpr>(device_expr.expr, Eigen::DefaultDevice()); - typedef TensorEvaluatorContainer<DevExpr> DevEvaluator; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<LHSHostExpr>::Type LHSDevExpr; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<RHSHostExpr>::Type RHSDevExpr; + auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<LHSDevExpr, LHSPlaceHolderExpr>(lhs_functors, left_tuple_of_accessors); + auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<RHSDevExpr, RHSPlaceHolderExpr>(rhs_functors, right_tuple_of_accessors); + typedef decltype(lhs_dev_expr.expr) LeftArgType; + typedef decltype(rhs_dev_expr.expr) RightArgType; + typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; + typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, - typename DevEvaluator::LeftEvaluator, LeftNocontractT, + LeftEvaluator, LeftNocontractT, ContractT, 1, lhs_inner_dim_contiguous, false, Unaligned, MakeGlobalPointer> LhsMapper; typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, - typename DevEvaluator::RightEvaluator, RightNocontractT, + RightEvaluator, RightNocontractT, ContractT, 1, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper; // initialize data mappers must happen inside the kernel for device eval - LhsMapper lhs(device_evaluator.m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); - RhsMapper rhs(device_evaluator.m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides); + LhsMapper lhs(LeftEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(), + lhs_dev_expr.expr, rhs_dev_expr.expr), dev), m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); + RhsMapper rhs(RightEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(), + rhs_dev_expr.expr, lhs_dev_expr.expr),dev), m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides); auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res); // Matmul Kernel // Thread identifiers @@ -327,7 +310,6 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr firstHalf++; } while (firstHalf<numTiles); - // Store the final results in C for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) { int globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM; @@ -364,35 +346,52 @@ template< typename Self, typename OutScalar, typename Index, typename ContractT, static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K, ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides, LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides){ - // create a tuple of accessors from Evaluator + typedef typename Self::XprType HostExpr; - // typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<HostExpr>::Type PlaceHolderExpr; - // typedef KernelNameConstructor<PlaceHolderExpr, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered> KernelName; - auto functors = Eigen::TensorSycl::internal::extractFunctors(self); - typedef decltype(functors) FunctorExpr; + typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr; + typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr; + typedef TensorEvaluator<LHSHostExpr, const Eigen::SyclDevice> OrigLHSExpr; + typedef TensorEvaluator<RHSHostExpr, const Eigen::SyclDevice> OrigRHSExpr; + typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigLHSExpr> LHSFunctorExpr; + typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigRHSExpr> RHSFunctorExpr; + // extract lhs functor list + LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + // extract rhs functor list + RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + Index roundUpK = RoundUp(K, TileSizeDimK); Index roundUpM = RoundUp(M, TileSizeDimM); Index roundUpN = RoundUp(N, TileSizeDimN); + self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) { - auto tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<Self>(cgh, self); - typedef decltype(tuple_of_accessors) TupleType; + /// work-around for gcc bug + typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType; + /// work-around for gcc bug + typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl())) RHSTupleType; + // create lhs tuple of accessors + LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl()); + // create rhs tuple of accessors + RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl()); + // Local memory for elements of Lhs typedef cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LhsLocalAcc; LhsLocalAcc localLhs(cl::sycl::range<1>(2* TileSizeDimM * TileSizeDimK), cgh); // Local memory for elements of Rhs typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc; RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh); + + typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write, cl::sycl::access::target::global_buffer> OutAccessor; //OutScalar memory - auto out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer); - typedef decltype(out_res) OutAccessor; + OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer); + // sycl parallel for cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN), cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)), - KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, FunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT, + KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT, RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK, - WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, TupleType>(functors, + WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors, localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, - m_left_nocontract_strides,m_right_nocontract_strides, tuple_of_accessors)); + m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice())); }); self.device().asynchronousExec(); } |