diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-17 14:09:37 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-17 14:09:37 -0700 |
commit | c95aacab90e9d8bb9f9e082395b3b843a530fa41 (patch) | |
tree | 1bf812626899fb13c0e1cfc350ecbac95bd135c2 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | |
parent | 038b55464b1d43612b88789f26006163ca638928 (diff) |
Fix TensorContractionOp evaluators for GPU and SYCL
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h index e6840bc87..35f931c53 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h @@ -23,15 +23,18 @@ namespace Eigen { template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels; -template<typename Indices, typename LeftArgType, typename RightArgType> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> : - public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > { +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> : + public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > { + + static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "SYCL tensor contraction does not support output kernels."); typedef const Eigen::SyclDevice Device; - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; |