From e46e7223817cfd982edec6d8e25c77e8e2493d78 Mon Sep 17 00:00:00 2001 From: Mehdi Goli Date: Mon, 16 Jan 2017 13:58:49 +0000 Subject: Adding Tensor ReverseOp; TensorStriding; TensorConversionOp; Modifying Tensor Contractsycl to be located in any place in the expression tree. --- .../Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 131 ++++++++++----------- 1 file changed, 65 insertions(+), 66 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index b170a1a5c..dc16f89e0 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -146,9 +146,9 @@ struct TensorEvaluatorm_device.memset(buffer, 0, m * n * sizeof(Scalar)); - LaunchSyclKernels::Run(*this, buffer, m, n, k, - this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides, - this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides); + LaunchSyclKernels::Run(*this, buffer, m, n, k, + this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides, + this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides); } // required by sycl to construct the expr on the device. Returns original left_impl const TensorEvaluator& left_impl() const { @@ -158,47 +158,18 @@ struct TensorEvaluator& right_impl() const { return choose(Cond(Layout) == static_cast(ColMajor)>(), this->m_rightImpl, this->m_leftImpl); } - // required by sycl to construct the expr on the device - const Indices& indices() const {return this->m_expr_indices;} }; -/// Dummy container on the device. This is used to avoid calling the constructor of TensorEvaluator for TensorContractionOp. This makes the code much faster. -template struct TensorEvaluatorContainer; -template -struct TensorEvaluatorContainer>{ - typedef Eigen::DefaultDevice Device; - typedef TensorContractionOp XprType; - typedef typename internal::remove_const::type Scalar; - typedef typename XprType::Index Index; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename PacketType::type PacketReturnType; - enum { - Layout = TensorEvaluator::Layout, - }; - - typedef typename internal::conditional(Layout) == static_cast(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; - typedef typename internal::conditional(Layout) == static_cast(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; - - TensorEvaluatorContainer(const XprType& op, const Eigen::DefaultDevice& device) - : m_leftImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), - op.lhsExpression(), op.rhsExpression()), device), - m_rightImpl(choose(Cond(Layout) == static_cast(ColMajor)>(), - op.rhsExpression(), op.lhsExpression()), device){} -LeftEvaluator m_leftImpl; -RightEvaluator m_rightImpl; -}; - - -template struct KernelConstructor{ - - typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression::Type PlaceHolderExpr; - - FunctorExpr functors; +int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{ + typedef typename Eigen::internal::traits::_LhsNested LHSHostExpr; + typedef typename Eigen::internal::traits::_RhsNested RHSHostExpr; + typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression::Type LHSPlaceHolderExpr; + typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression::Type RHSPlaceHolderExpr; + LHSFunctorExpr lhs_functors; + RHSFunctorExpr rhs_functors; LhsLocalAcc localLhs; RhsLocalAcc localRhs; OutAccessor out_res; @@ -206,38 +177,50 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides; LeftNocontractT m_i_strides, m_left_nocontract_strides; RightNocontractT m_j_strides, m_right_nocontract_strides; - TupleType tuple_of_accessors; + LHSTupleType left_tuple_of_accessors; + RHSTupleType right_tuple_of_accessors; + Device dev; + - KernelConstructor(FunctorExpr functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, + KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_, ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_, - LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, TupleType tuple_of_accessors_) - :functors(functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), + LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_) + :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_), m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_), m_right_contracting_strides(m_right_contracting_strides_), m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_), m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_), - tuple_of_accessors(tuple_of_accessors_){} + left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_){} void operator()(cl::sycl::nd_item<1> itemID) { - typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression::Type DevExpr; - auto device_expr =Eigen::TensorSycl::internal::createDeviceExpression(functors, tuple_of_accessors); - auto device_evaluator = TensorEvaluatorContainer(device_expr.expr, Eigen::DefaultDevice()); - typedef TensorEvaluatorContainer DevEvaluator; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression::Type DevExpr; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression::Type LHSDevExpr; + typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression::Type RHSDevExpr; + auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression(lhs_functors, left_tuple_of_accessors); + auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression(rhs_functors, right_tuple_of_accessors); + typedef decltype(lhs_dev_expr.expr) LeftArgType; + typedef decltype(rhs_dev_expr.expr) RightArgType; + typedef typename internal::conditional(Eigen::internal::traits::Layout) == static_cast(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional(Eigen::internal::traits::Layout) == static_cast(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType; + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; typedef internal::TensorContractionInputMapper LhsMapper; typedef internal::TensorContractionInputMapper RhsMapper; // initialize data mappers must happen inside the kernel for device eval - LhsMapper lhs(device_evaluator.m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); - RhsMapper rhs(device_evaluator.m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides); + LhsMapper lhs(LeftEvaluator(choose(Cond(Eigen::internal::traits::Layout) == static_cast(ColMajor)>(), + lhs_dev_expr.expr, rhs_dev_expr.expr), dev), m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); + RhsMapper rhs(RightEvaluator(choose(Cond(Eigen::internal::traits::Layout) == static_cast(ColMajor)>(), + rhs_dev_expr.expr, lhs_dev_expr.expr),dev), m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides); auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res); // Matmul Kernel // Thread identifiers @@ -327,7 +310,6 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr firstHalf++; } while (firstHalf::Type PlaceHolderExpr; - // typedef KernelNameConstructor KernelName; - auto functors = Eigen::TensorSycl::internal::extractFunctors(self); - typedef decltype(functors) FunctorExpr; + typedef typename Eigen::internal::traits::_LhsNested LHSHostExpr; + typedef typename Eigen::internal::traits::_RhsNested RHSHostExpr; + typedef TensorEvaluator OrigLHSExpr; + typedef TensorEvaluator OrigRHSExpr; + typedef Eigen::TensorSycl::internal::FunctorExtractor LHSFunctorExpr; + typedef Eigen::TensorSycl::internal::FunctorExtractor RHSFunctorExpr; + // extract lhs functor list + LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + // extract rhs functor list + RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl()); + Index roundUpK = RoundUp(K, TileSizeDimK); Index roundUpM = RoundUp(M, TileSizeDimM); Index roundUpN = RoundUp(N, TileSizeDimN); + self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) { - auto tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors(cgh, self); - typedef decltype(tuple_of_accessors) TupleType; + /// work-around for gcc bug + typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors(cgh, self.left_impl())) LHSTupleType; + /// work-around for gcc bug + typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors(cgh, self.right_impl())) RHSTupleType; + // create lhs tuple of accessors + LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors(cgh, self.left_impl()); + // create rhs tuple of accessors + RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors(cgh, self.right_impl()); + // Local memory for elements of Lhs typedef cl::sycl::accessor LhsLocalAcc; LhsLocalAcc localLhs(cl::sycl::range<1>(2* TileSizeDimM * TileSizeDimK), cgh); // Local memory for elements of Rhs typedef cl::sycl::accessor RhsLocalAcc; RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh); + + typedef cl::sycl::accessor OutAccessor; //OutScalar memory - auto out_res= self.device(). template get_sycl_accessor(cgh, buffer); - typedef decltype(out_res) OutAccessor; + OutAccessor out_res= self.device(). template get_sycl_accessor(cgh, buffer); + // sycl parallel for cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN), cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)), - KernelConstructor(functors, + WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors, localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides, - m_left_nocontract_strides,m_right_nocontract_strides, tuple_of_accessors)); + m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice())); }); self.device().asynchronousExec(); } -- cgit v1.2.3