aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
diff options
context:
space:
mode:
authorGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-01-16 13:58:49 +0000
committerGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-01-16 13:58:49 +0000
commite46e7223817cfd982edec6d8e25c77e8e2493d78 (patch)
tree3b8345ae7bb7ab2434b117932aea51f016acf43d /unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
parent23778a15d8570b4287820f540b719203e07cfb44 (diff)
Adding Tensor ReverseOp; TensorStriding; TensorConversionOp; Modifying Tensor Contractsycl to be located in any place in the expression tree.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h131
1 files changed, 65 insertions, 66 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
index b170a1a5c..dc16f89e0 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h
@@ -146,9 +146,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
- LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
- this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
- this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
+ LaunchSyclKernels<LhsScalar, RhsScalar,lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
+ this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
+ this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
}
// required by sycl to construct the expr on the device. Returns original left_impl
const TensorEvaluator<LeftArgType, Device>& left_impl() const {
@@ -158,47 +158,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const TensorEvaluator<RightArgType, Device>& right_impl() const {
return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_rightImpl, this->m_leftImpl);
}
- // required by sycl to construct the expr on the device
- const Indices& indices() const {return this->m_expr_indices;}
};
-/// Dummy container on the device. This is used to avoid calling the constructor of TensorEvaluator for TensorContractionOp. This makes the code much faster.
-template<typename Expr> struct TensorEvaluatorContainer;
-template<typename Indices, typename LeftArgType, typename RightArgType>
-struct TensorEvaluatorContainer<TensorContractionOp<Indices, LeftArgType, RightArgType>>{
- typedef Eigen::DefaultDevice Device;
- typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
- typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
- typedef typename XprType::Index Index;
- typedef typename XprType::CoeffReturnType CoeffReturnType;
- typedef typename PacketType<CoeffReturnType, Eigen::DefaultDevice>::type PacketReturnType;
- enum {
- Layout = TensorEvaluator<LeftArgType, Device>::Layout,
- };
-
- typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
- typedef typename internal::conditional<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
- typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
- typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
-
- TensorEvaluatorContainer(const XprType& op, const Eigen::DefaultDevice& device)
- : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
- op.lhsExpression(), op.rhsExpression()), device),
- m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
- op.rhsExpression(), op.lhsExpression()), device){}
-LeftEvaluator m_leftImpl;
-RightEvaluator m_rightImpl;
-};
-
-
-template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename FunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
+template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename LHSFunctorExpr, typename RHSFunctorExpr, typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
int TileSizeDimM, int TileSizeDimN,int TileSizeDimK, int WorkLoadPerThreadM,int WorkLoadPerThreadN,
-int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename TupleType> struct KernelConstructor{
-
- typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<HostExpr>::Type PlaceHolderExpr;
-
- FunctorExpr functors;
+int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThreadRhs, typename LHSTupleType, typename RHSTupleType, typename Device> struct KernelConstructor{
+ typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
+ typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
+ typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr;
+ typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<RHSHostExpr>::Type RHSPlaceHolderExpr;
+ LHSFunctorExpr lhs_functors;
+ RHSFunctorExpr rhs_functors;
LhsLocalAcc localLhs;
RhsLocalAcc localRhs;
OutAccessor out_res;
@@ -206,38 +177,50 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr
ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides;
LeftNocontractT m_i_strides, m_left_nocontract_strides;
RightNocontractT m_j_strides, m_right_nocontract_strides;
- TupleType tuple_of_accessors;
+ LHSTupleType left_tuple_of_accessors;
+ RHSTupleType right_tuple_of_accessors;
+ Device dev;
+
- KernelConstructor(FunctorExpr functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_,
+ KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_,
Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_,
ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_,
- LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, TupleType tuple_of_accessors_)
- :functors(functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
+ LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_)
+ :lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_),
m_right_contracting_strides(m_right_contracting_strides_),
m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_),
m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_),
- tuple_of_accessors(tuple_of_accessors_){}
+ left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_){}
void operator()(cl::sycl::nd_item<1> itemID) {
- typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr;
- auto device_expr =Eigen::TensorSycl::internal::createDeviceExpression<DevExpr, PlaceHolderExpr>(functors, tuple_of_accessors);
- auto device_evaluator = TensorEvaluatorContainer<DevExpr>(device_expr.expr, Eigen::DefaultDevice());
- typedef TensorEvaluatorContainer<DevExpr> DevEvaluator;
+ typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr;
+ typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<LHSHostExpr>::Type LHSDevExpr;
+ typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<RHSHostExpr>::Type RHSDevExpr;
+ auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<LHSDevExpr, LHSPlaceHolderExpr>(lhs_functors, left_tuple_of_accessors);
+ auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<RHSDevExpr, RHSPlaceHolderExpr>(rhs_functors, right_tuple_of_accessors);
+ typedef decltype(lhs_dev_expr.expr) LeftArgType;
+ typedef decltype(rhs_dev_expr.expr) RightArgType;
+ typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
+ typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
+ typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
- typename DevEvaluator::LeftEvaluator, LeftNocontractT,
+ LeftEvaluator, LeftNocontractT,
ContractT, 1,
lhs_inner_dim_contiguous,
false, Unaligned, MakeGlobalPointer> LhsMapper;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
- typename DevEvaluator::RightEvaluator, RightNocontractT,
+ RightEvaluator, RightNocontractT,
ContractT, 1,
rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper;
// initialize data mappers must happen inside the kernel for device eval
- LhsMapper lhs(device_evaluator.m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
- RhsMapper rhs(device_evaluator.m_rightImpl, m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
+ LhsMapper lhs(LeftEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
+ lhs_dev_expr.expr, rhs_dev_expr.expr), dev), m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
+ RhsMapper rhs(RightEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
+ rhs_dev_expr.expr, lhs_dev_expr.expr),dev), m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res);
// Matmul Kernel
// Thread identifiers
@@ -327,7 +310,6 @@ int LocalThreadSizeM, int LocalThreadSizeN, int LoadPerThreadLhs, int LoadPerThr
firstHalf++;
} while (firstHalf<numTiles);
-
// Store the final results in C
for (int wLPTM=0; wLPTM<WorkLoadPerThreadM; wLPTM++) {
int globalRow = mGroupId*TileSizeDimM + mLocalThreadId + wLPTM*LocalThreadSizeM;
@@ -364,35 +346,52 @@ template< typename Self, typename OutScalar, typename Index, typename ContractT,
static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K,
ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides){
- // create a tuple of accessors from Evaluator
+
typedef typename Self::XprType HostExpr;
- // typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<HostExpr>::Type PlaceHolderExpr;
- // typedef KernelNameConstructor<PlaceHolderExpr, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered> KernelName;
- auto functors = Eigen::TensorSycl::internal::extractFunctors(self);
- typedef decltype(functors) FunctorExpr;
+ typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
+ typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
+ typedef TensorEvaluator<LHSHostExpr, const Eigen::SyclDevice> OrigLHSExpr;
+ typedef TensorEvaluator<RHSHostExpr, const Eigen::SyclDevice> OrigRHSExpr;
+ typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigLHSExpr> LHSFunctorExpr;
+ typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigRHSExpr> RHSFunctorExpr;
+ // extract lhs functor list
+ LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
+ // extract rhs functor list
+ RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
+
Index roundUpK = RoundUp(K, TileSizeDimK);
Index roundUpM = RoundUp(M, TileSizeDimM);
Index roundUpN = RoundUp(N, TileSizeDimN);
+
self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) {
- auto tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<Self>(cgh, self);
- typedef decltype(tuple_of_accessors) TupleType;
+ /// work-around for gcc bug
+ typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType;
+ /// work-around for gcc bug
+ typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl())) RHSTupleType;
+ // create lhs tuple of accessors
+ LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl());
+ // create rhs tuple of accessors
+ RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl());
+
// Local memory for elements of Lhs
typedef cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LhsLocalAcc;
LhsLocalAcc localLhs(cl::sycl::range<1>(2* TileSizeDimM * TileSizeDimK), cgh);
// Local memory for elements of Rhs
typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc;
RhsLocalAcc localRhs(cl::sycl::range<1>(2* TileSizeDimK * TileSizeDimN), cgh);
+
+ typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write, cl::sycl::access::target::global_buffer> OutAccessor;
//OutScalar memory
- auto out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer);
- typedef decltype(out_res) OutAccessor;
+ OutAccessor out_res= self.device(). template get_sycl_accessor<cl::sycl::access::mode::write>(cgh, buffer);
+
// sycl parallel for
cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM/WorkLoadPerThreadM, roundUpN/WorkLoadPerThreadN),
cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)),
- KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, FunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
+ KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK,
- WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, TupleType>(functors,
+ WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::DefaultDevice>(lhs_functors, rhs_functors,
localLhs, localRhs, out_res, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides,m_i_strides, m_j_strides,
- m_left_nocontract_strides,m_right_nocontract_strides, tuple_of_accessors));
+ m_left_nocontract_strides,m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::DefaultDevice()));
});
self.device().asynchronousExec();
}